-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
157 lines (120 loc) · 5.08 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""
Convert SQLAlchemy model classes to WTForms form classes.
"""
import typing
import sqlalchemy
import wtforms
# A WTForms form *class* (not an instance); the return type of
# `ModelFormMixin.get_model_form`.
FormType = typing.Type[wtforms.Form]
# A WTForms field *class* (not an instance); the return type of
# `ColumnFieldConverter.get_field_type`.
FieldType = typing.Type[wtforms.Field]
# Keyword arguments for constructing a field, keyed by parameter name; the
# return type of `ColumnFieldConverter.get_field_kwargs`.
FieldKwargs = dict[str, typing.Any]
class ColumnFieldConverter:
    """
    A converter from specific SQLAlchemy columns to WTForms fields.

    Subclasses override `get_field_type` to pick the WTForms field class and
    may extend `get_field_kwargs` to add field-specific constructor arguments.
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """
        Get the field type
        for the columns for which this converter is intended.

        The base implementation returns the generic `wtforms.Field`.
        """
        return wtforms.Field

    def get_field_kwargs(self, column: sqlalchemy.Column) -> FieldKwargs:
        """
        Get the keyword arguments to construct fields
        for the columns for which this converter is intended.

        The returned dict contains:

        - `label`: the column name with underscores replaced by spaces,
          title-cased.
        - `description`: the column's `doc`, or an empty string if it has none.
        - `default`: the column's Python-side scalar default, or None.
        - `validators`: a new list holding `Optional()` for nullable columns
          and `InputRequired()` otherwise; subclasses may append to it.
        """
        return {
            "label": column.name.replace("_", " ").title(),
            # WTForms' own default description is "", so keep that when the
            # column has no `doc` instead of passing `None` through.
            "description": getattr(column, "doc", None) or "",
            "default": self._get_default(column),
            "validators": [
                wtforms.validators.Optional()
                if column.nullable
                else wtforms.validators.InputRequired()
            ],
        }

    @staticmethod
    def _get_default(column: sqlalchemy.Column) -> typing.Any:
        """
        Return the column's scalar Python default value, or None.

        SQLAlchemy stores a `default=` argument wrapped in a `ColumnDefault`
        object, so `column.default` is not the value itself and must be
        unwrapped via `.arg`. Only scalar defaults are unwrapped: callable
        defaults are wrapped by SQLAlchemy to take an execution context, and
        SQL-expression or sequence defaults have no Python value to show.
        """
        column_default = getattr(column, "default", None)
        if column_default is not None and getattr(column_default, "is_scalar", False):
            return column_default.arg
        return None
class StringColumnFieldConverter(ColumnFieldConverter):
    """
    Convert string columns to `StringField`s that enforce the column length.
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """Return `wtforms.StringField`."""
        return wtforms.StringField

    def get_field_kwargs(self, column: sqlalchemy.Column) -> FieldKwargs:
        """
        Extend the base keyword arguments with a `Length(max=...)` validator
        when the column type declares a (truthy) length.
        """
        kwargs = super().get_field_kwargs(column)
        max_length = column.type.length
        # A falsy length (e.g. None for an unsized String) means there is no
        # maximum to enforce.
        if not max_length:
            return kwargs
        kwargs["validators"].append(wtforms.validators.Length(max=max_length))
        return kwargs
class IntegerColumnFieldConverter(ColumnFieldConverter):
    """
    Convert integer columns to `IntegerField`s.
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """Return `wtforms.IntegerField`."""
        field_class = wtforms.IntegerField
        return field_class
# NOTE: The WTForms `datetime-local` field (`DateTimeLocalField`) has an
# incorrect format value, which affects how default values are rendered.
# It also does not convert times from the user's timezone to UTC.
class DateTimeColumnFieldConverter(ColumnFieldConverter):
    """
    Convert date-time columns to `DateTimeLocalField`s.

    See the NOTE directly above this class for known limitations.
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """Return `wtforms.DateTimeLocalField`."""
        field_class = wtforms.DateTimeLocalField
        return field_class
class DateColumnFieldConverter(ColumnFieldConverter):
    """
    Convert date columns to `DateField`s.
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """Return `wtforms.DateField`."""
        field_class = wtforms.DateField
        return field_class
class TimeColumnFieldConverter(ColumnFieldConverter):
    """
    Convert time-of-day columns to `TimeField`s.
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """Return `wtforms.TimeField`."""
        field_class = wtforms.TimeField
        return field_class
class BooleanColumnFieldConverter(ColumnFieldConverter):
    """
    Convert boolean columns to `BooleanField`s (checkboxes).
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """Return `wtforms.BooleanField`."""
        return wtforms.BooleanField

    def get_field_kwargs(self, column: sqlalchemy.Column) -> FieldKwargs:
        """
        Use the base keyword arguments, but never require input.

        An unchecked checkbox submits no data at all. With the
        `InputRequired` validator that the base class adds for non-nullable
        columns, `False` could never be submitted. `BooleanField` turns
        missing input into `False` on its own, so `Optional` is used for
        every boolean column instead.
        """
        field_kwargs = super().get_field_kwargs(column)
        field_kwargs["validators"] = [wtforms.validators.Optional()]
        return field_kwargs
class EnumColumnFieldConverter(ColumnFieldConverter):
    """
    Convert enum columns to `SelectField`s offering the enum's values.
    """

    def get_field_type(self, column: sqlalchemy.Column) -> FieldType:
        """Return `wtforms.SelectField`."""
        return wtforms.SelectField

    def get_field_kwargs(self, column: sqlalchemy.Column) -> FieldKwargs:
        """
        Extend the base keyword arguments with `choices`: one
        `(value, value.title())` pair per entry of `column.type.enums`.
        """
        kwargs = super().get_field_kwargs(column)
        allowed_values = column.type.enums
        kwargs["choices"] = [(value, value.title()) for value in allowed_values]
        return kwargs
class ModelFormMixin:
    """
    A mixin that adds a `get_model_form` class method to the form class,
    which returns a class with fields matching the columns of the model.

    The generated form class subclasses the class that `get_model_form` is
    called on.
    """

    # Default converter per SQLAlchemy column type. Lookups also match
    # subclasses of these types (e.g. `Text` uses the `String` converter);
    # see `_find_converter`.
    converters: dict[typing.Type, ColumnFieldConverter] = {
        sqlalchemy.types.String: StringColumnFieldConverter(),
        sqlalchemy.types.Integer: IntegerColumnFieldConverter(),
        sqlalchemy.types.DateTime: DateTimeColumnFieldConverter(),
        sqlalchemy.types.Date: DateColumnFieldConverter(),
        sqlalchemy.types.Time: TimeColumnFieldConverter(),
        sqlalchemy.types.Boolean: BooleanColumnFieldConverter(),
        sqlalchemy.types.Enum: EnumColumnFieldConverter(),
    }

    @classmethod
    def _find_converter(cls, column_type: typing.Type) -> ColumnFieldConverter:
        """
        Return the converter registered for `column_type` or its nearest base.

        The MRO is walked starting with `column_type` itself, so an exact
        registration always wins. Subclasses such as `Text`, `Unicode` or
        `BigInteger` fall back to the converter of `String` / `Integer`.

        :raises KeyError: If no type in the MRO has a registered converter.
        """
        for candidate in column_type.__mro__:
            if candidate in cls.converters:
                return cls.converters[candidate]
        raise KeyError(
            "No converter currently exists"
            f" for SQLAlchemy columns of the type `{column_type}`."
        )

    @classmethod
    def get_model_form(
        cls,
        model,
        names: list[str],
        column_converters: typing.Optional[dict[str, ColumnFieldConverter]] = None,
    ) -> FormType:
        """
        Create a WTForms form from an SQLAlchemy model.

        :param model: The SQLAlchemy model class to inspect.
        :param names: Column names (matched against `Column.name`) to turn
            into fields, in the order given.
        :param column_converters: Optional converters keyed by column name.
            They take precedence over the type-based `converters` and may be
            used for columns whose type has no registered converter.
        :returns: A new subclass of `cls` with one field per entry in `names`.
        :raises KeyError: If a name matches no column of the model, or if a
            column has neither a custom converter nor one for its type.
        """
        if column_converters is None:
            column_converters = {}

        class ModelForm(cls):
            """
            The form class created from the model.
            """

        columns = sqlalchemy.inspect(model).columns
        name_to_column_map = {column.name: column for column in columns}
        for name in names:
            try:
                column = name_to_column_map[name]
            except KeyError:
                raise KeyError(
                    f"The submitted name `{name}`"
                    " does not match any column"
                    f" of the SQLAlchemy model `{model.__name__}`."
                ) from None
            # A custom converter for this column takes precedence over the
            # type-based lookup, so it also covers unsupported column types.
            if name in column_converters:
                converter = column_converters[name]
            else:
                converter = cls._find_converter(type(column.type))
            field_type = converter.get_field_type(column)
            field_kwargs = converter.get_field_kwargs(column)
            # Fields are attached after the class is created; this relies on
            # WTForms' form metaclass accepting fields set this way.
            setattr(ModelForm, name, field_type(**field_kwargs))
        return ModelForm