-
Notifications
You must be signed in to change notification settings - Fork 96
/
sqlalchemy.py
289 lines (210 loc) · 8.65 KB
/
sqlalchemy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
""" SQLAlchemy support. """
from __future__ import absolute_import
import datetime
from types import GeneratorType
import decimal
from sqlalchemy import func
# from sqlalchemy.orm.interfaces import MANYTOONE
from sqlalchemy.orm.collections import InstrumentedList
from sqlalchemy.sql.type_api import TypeDecorator
try:
from sqlalchemy.orm.relationships import RelationshipProperty
except ImportError:
from sqlalchemy.orm.properties import RelationshipProperty
from sqlalchemy.types import (
BIGINT, BOOLEAN, BigInteger, Boolean, CHAR, DATE, DATETIME, DECIMAL, Date,
DateTime, FLOAT, Float, INT, INTEGER, Integer, NCHAR, NVARCHAR, NUMERIC,
Numeric, SMALLINT, SmallInteger, String, TEXT, TIME, Text, Time, Unicode,
UnicodeText, VARCHAR, Enum)
from .. import mix_types as t
from ..main import (
SKIP_VALUE, LOGGER, TypeMixer as BaseTypeMixer, GenFactory as BaseFactory,
Mixer as BaseMixer, partial, faker)
class GenFactory(BaseFactory):
""" Map a sqlalchemy classes to simple types. """
types = {
(String, VARCHAR, Unicode, NVARCHAR, NCHAR, CHAR): str,
(Text, UnicodeText, TEXT): t.Text,
(Boolean, BOOLEAN): bool,
(Date, DATE): datetime.date,
(DateTime, DATETIME): datetime.datetime,
(Time, TIME): datetime.time,
(DECIMAL, Numeric, NUMERIC): decimal.Decimal,
(Float, FLOAT): float,
(Integer, INTEGER, INT): int,
(BigInteger, BIGINT): t.BigInteger,
(SmallInteger, SMALLINT): t.SmallInteger,
}
class TypeMixer(BaseTypeMixer):
""" TypeMixer for SQLAlchemy. """
factory = GenFactory
def __init__(self, cls, **params):
""" Init TypeMixer and save the mapper. """
super(TypeMixer, self).__init__(cls, **params)
self.mapper = self.__scheme._sa_class_manager.mapper
def postprocess(self, target, postprocess_values):
""" Fill postprocess values. """
mixed = []
for name, deffered in postprocess_values:
value = deffered.value
if isinstance(value, GeneratorType):
value = next(value)
if isinstance(value, t.Mix):
mixed.append((name, value))
continue
if isinstance(getattr(target, name), InstrumentedList) and not isinstance(value, list):
value = [value]
setattr(target, name, value)
for name, mix in mixed:
setattr(target, name, mix & target)
if self.__mixer:
target = self.__mixer.postprocess(target)
return target
@staticmethod
def get_default(field):
""" Get default value from field.
:return value: A default value or NO_VALUE
"""
column = field.scheme
if isinstance(column, RelationshipProperty):
column = column.local_remote_pairs[0][0]
if not column.default:
return SKIP_VALUE
if column.default.is_callable:
return column.default.arg(None)
return getattr(column.default, 'arg', SKIP_VALUE)
def gen_select(self, field_name, select):
""" Select exists value from database.
:param field_name: Name of field for generation.
:return : None or (name, value) for later use
"""
if not self.__mixer or not self.__mixer.params.get('session'):
return field_name, SKIP_VALUE
relation = self.mapper.get_property(field_name)
session = self.__mixer.params.get('session')
value = session.query(
relation.mapper.class_
).filter(*select.choices).order_by(func.random()).first()
return self.get_value(field_name, value)
@staticmethod
def is_unique(field):
""" Return True is field's value should be a unique.
:return bool:
"""
scheme = field.scheme
if isinstance(scheme, RelationshipProperty):
scheme = scheme.local_remote_pairs[0][0]
return scheme.unique
@staticmethod
def is_required(field):
""" Return True is field's value should be defined.
:return bool:
"""
column = field.scheme
if isinstance(column, RelationshipProperty):
return False
if field.params:
return True
# According to the SQLAlchemy docs, autoincrement "only has an effect for columns which are
# Integer derived (i.e. INT, SMALLINT, BIGINT) [and] Part of the primary key [...]".
return not column.nullable and not (column.autoincrement and column.primary_key and
isinstance(column.type, Integer))
def get_value(self, field_name, field_value):
""" Get `value` as `field_name`.
:return : None or (name, value) for later use
"""
field = self.__fields.get(field_name)
if field and isinstance(field.scheme, RelationshipProperty):
return field_name, t._Deffered(field_value, field.scheme)
return super(TypeMixer, self).get_value(field_name, field_value)
def make_fabric(self, column, field_name=None, fake=False, kwargs=None): # noqa
""" Make values fabric for column.
:param column: SqlAlchemy column
:param field_name: Field name
:param fake: Force fake data
:return function:
"""
kwargs = {} if kwargs is None else kwargs
if isinstance(column, RelationshipProperty):
return partial(type(self)(
column.mapper.class_, mixer=self.__mixer, fake=self.__fake, factory=self.__factory
).blend, **kwargs)
ftype = type(column.type)
# augmented types created with TypeDecorator
# don't directly inherit from the base types
if TypeDecorator in ftype.__bases__:
ftype = ftype.impl
stype = self.__factory.cls_to_simple(ftype)
if stype is str:
fab = super(TypeMixer, self).make_fabric(
stype, field_name=field_name, fake=fake, kwargs=kwargs)
return lambda: fab()[:column.type.length]
if ftype is Enum:
return partial(faker.random_element, column.type.enums)
return super(TypeMixer, self).make_fabric(
stype, field_name=field_name, fake=fake, kwargs=kwargs)
def guard(self, *args, **kwargs):
""" Look objects in database.
:returns: A finded object or False
"""
try:
session = self.__mixer.params.get('session')
assert session
except (AttributeError, AssertionError):
raise ValueError('Cannot make request to DB.')
qs = session.query(self.mapper).filter(*args, **kwargs)
count = qs.count()
if count == 1:
return qs.first()
if count:
return qs.all()
return False
def reload(self, obj):
""" Reload object from database. """
try:
session = self.__mixer.params.get('session')
session.expire(obj)
session.refresh(obj)
return obj
except (AttributeError, AssertionError):
raise ValueError('Cannot make request to DB.')
def __load_fields(self):
""" Prepare SQLALchemyTypeMixer.
Select columns and relations for data generation.
"""
mapper = self.__scheme._sa_class_manager.mapper
relations = set()
if hasattr(mapper, 'relationships'):
for rel in mapper.relationships:
relations |= rel.local_columns
yield rel.key, t.Field(rel, rel.key)
for key, column in mapper.columns.items():
if column not in relations:
yield key, t.Field(column, key)
class Mixer(BaseMixer):
""" Integration with SQLAlchemy. """
type_mixer_cls = TypeMixer
def __init__(self, session=None, commit=True, **params):
"""Initialize the SQLAlchemy Mixer.
:param fake: (True) Generate fake data instead of random data.
:param session: SQLAlchemy session. Using for commits.
:param commit: (True) Commit instance to session after creation.
"""
super(Mixer, self).__init__(**params)
self.params['session'] = session
self.params['commit'] = bool(session) and commit
def postprocess(self, target):
""" Save objects in db.
:return value: A generated value
"""
if self.params.get('commit'):
session = self.params.get('session')
if not session:
LOGGER.warn("'commit' set true but session not initialized.")
else:
session.add(target)
session.commit()
return target
# Default mixer
mixer = Mixer()
# pylama:ignore=E1120,E0611