From a57c4312e9000f15619299e91018993e4bf8d39c Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Fri, 24 May 2019 16:57:56 +1000 Subject: [PATCH 1/5] Use specific EnumSelectField for PEP-435 enum columns --- wtforms_sqlalchemy/fields.py | 38 ++++++++++++++++++++++++++++++++++++ wtforms_sqlalchemy/orm.py | 11 ++++++++--- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/wtforms_sqlalchemy/fields.py b/wtforms_sqlalchemy/fields.py index da5cb47..ab2bbb9 100644 --- a/wtforms_sqlalchemy/fields.py +++ b/wtforms_sqlalchemy/fields.py @@ -202,3 +202,41 @@ class QueryCheckboxField(QuerySelectMultipleField): def get_pk_from_identity(obj): key = identity_key(instance=obj)[1] return ':'.join(text_type(x) for x in key) + + +class EnumSelectField(SelectFieldBase): + widget = widgets.Select() + + def __init__(self, *args, enum, members=None, **kwargs): + super().__init__(*args, **kwargs) + self.enum = enum + if members is not None: + self.members = list(members) + else: + self.members = list(enum) + self.members_set = set(self.members) + self.members_map = {member.name: member for member in self.members} + + def to_choice(self, member): + return (member.name, str(member)) + + def iter_choices(self): + for member in self.members: + yield (member.name, str(member), member is self.data) + + def process_data(self, value): + if isinstance(value, self.enum): + self.data = value + else: + self.data = None + + def process_formdata(self, valuelist): + if valuelist: + try: + self.data = self.members_map[valuelist[0]] + except (ValueError, TypeError, KeyError): + raise ValueError(self.gettext('Invalid Choice: could not coerce')) + + def pre_validate(self, form): + if self.data not in self.members: + raise ValueError(self.gettext('Not a valid choice')) diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 704abdc..0e2f63d 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -7,7 +7,8 @@ from wtforms import validators, fields as wtforms_fields from wtforms.form import Form -from .fields import QuerySelectField, QuerySelectMultipleField + +from .fields import EnumSelectField, QuerySelectField, QuerySelectMultipleField __all__ = ( 'model_fields', 'model_form', @@ -179,8 +180,12 @@ def conv_DateTime(self, field_args, **extra): @converts('Enum') def conv_Enum(self, column, field_args, **extra): - field_args['choices'] = [(e, e) for e in column.type.enums] - return wtforms_fields.SelectField(**field_args) + if column.type.enum_class is not None: + field_args['enum'] = column.type.enum_class + return EnumSelectField(**field_args) + else: + field_args['choices'] = [(e, e) for e in column.type.enums] + return wtforms_fields.SelectField(**field_args) @converts('Integer') # includes BigInteger and SmallInteger def handle_integer_types(self, column, field_args, **extra): From 0cde09c0aab9f8c519fcda94201fc343295ef2bf Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Fri, 24 May 2019 16:59:05 +1000 Subject: [PATCH 2/5] Use SQLAlchemy's Column.doc for WTForm's Field.description --- wtforms_sqlalchemy/orm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 0e2f63d..15aabd5 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -83,6 +83,7 @@ def convert(self, model, mapper, prop, field_args, db_session=None): 'filters': [], 'default': None, } + kwargs['description'] = prop.doc if field_args: kwargs.update(field_args) From 9a7afac8f4c85f39c4ef9bc052be7879af7afbf9 Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Fri, 24 May 2019 16:59:34 +1000 Subject: [PATCH 3/5] Allow customising the QuerySelectField blank label --- wtforms_sqlalchemy/fields.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/wtforms_sqlalchemy/fields.py b/wtforms_sqlalchemy/fields.py index ab2bbb9..76ceced 100644 --- a/wtforms_sqlalchemy/fields.py +++ b/wtforms_sqlalchemy/fields.py @@ -57,7 +57,7 @@ class QuerySelectField(SelectFieldBase): def __init__(self, label=None, validators=None, query_factory=None, get_pk=None, get_label=None, allow_blank=False, - blank_text='', **kwargs): + blank_value='__None', blank_text='Select...', **kwargs): super(QuerySelectField, self).__init__(label, validators, **kwargs) self.query_factory = query_factory @@ -76,6 +76,7 @@ def __init__(self, label=None, validators=None, query_factory=None, self.get_label = get_label self.allow_blank = allow_blank + self.blank_value = blank_value self.blank_text = blank_text self.query = None self._object_list = None @@ -105,15 +106,15 @@ def _get_object_list(self): return self._object_list def iter_choices(self): - if self.allow_blank: - yield ('__None', self.blank_text, self.data is None) + if self.allow_blank or self.data is None: + yield (self.blank_value, self.blank_text, self.data is None) for pk, obj in self._get_object_list(): yield (pk, self.get_label(obj), obj == self.data) def process_formdata(self, valuelist): if valuelist: - if self.allow_blank and valuelist[0] == '__None': + if self.allow_blank and valuelist[0] == self.blank_value: self.data = None else: self._data = None From bf55fa540d23173a89a152f341247dca47064c20 Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Fri, 24 May 2019 17:00:16 +1000 Subject: [PATCH 4/5] Do not assume not-nullable means required Not-null Boolean columns fail this requirement, for example. will not allow the page to be submitted until it is checked, despite an unchecked/falsey checkbox being valid data at the model level. --- wtforms_sqlalchemy/orm.py | 47 ++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 15aabd5..7db9c4a 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -114,11 +114,6 @@ def convert(self, model, mapper, prop, field_args, db_session=None): kwargs['default'] = default - if column.nullable: - kwargs['validators'].append(validators.Optional()) - else: - kwargs['validators'].append(validators.Required()) - converter = self.get_converter(column) else: # We have a property with a direction. @@ -152,35 +147,47 @@ class ModelConverter(ModelConverterBase): def __init__(self, extra_converters=None, use_mro=True): super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro) + @classmethod + def _nullable_required(cls, column, field_args, **extra): + if column.nullable: + field_args['validators'].append(validators.Optional()) + else: + field_args['validators'].append(validators.Required()) + @classmethod def _string_common(cls, column, field_args, **extra): if isinstance(column.type.length, int) and column.type.length: field_args['validators'].append(validators.Length(max=column.type.length)) @converts('String') # includes Unicode - def conv_String(self, field_args, **extra): - self._string_common(field_args=field_args, **extra) + def conv_String(self, column, field_args, **extra): + self._string_common(column=column, field_args=field_args, **extra) + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.StringField(**field_args) @converts('Text', 'LargeBinary', 'Binary') # includes UnicodeText - def conv_Text(self, field_args, **extra): - self._string_common(field_args=field_args, **extra) + def conv_Text(self, column, field_args, **extra): + self._string_common(column=column, field_args=field_args, **extra) + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.TextAreaField(**field_args) @converts('Boolean', 'dialects.mssql.base.BIT') - def conv_Boolean(self, field_args, **extra): + def conv_Boolean(self, column, field_args, **extra): return wtforms_fields.BooleanField(**field_args) @converts('Date') - def conv_Date(self, field_args, **extra): + def conv_Date(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.DateField(**field_args) @converts('DateTime') - def conv_DateTime(self, field_args, **extra): + def conv_DateTime(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.DateTimeField(**field_args) @converts('Enum') def conv_Enum(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) if column.type.enum_class is not None: field_args['enum'] = column.type.enum_class return EnumSelectField(**field_args) @@ -202,34 +209,38 @@ def handle_decimal_types(self, column, field_args, **extra): return wtforms_fields.DecimalField(**field_args) @converts('dialects.mysql.types.YEAR', 'dialects.mysql.base.YEAR') - def conv_MSYear(self, field_args, **extra): + def conv_MSYear(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args['validators'].append(validators.NumberRange(min=1901, max=2155)) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.INET') - def conv_PGInet(self, field_args, **extra): + def conv_PGInet(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'IP Address') field_args['validators'].append(validators.IPAddress()) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.MACADDR') - def conv_PGMacaddr(self, field_args, **extra): + def conv_PGMacaddr(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'MAC Address') field_args['validators'].append(validators.MacAddress()) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.UUID') - def conv_PGUuid(self, field_args, **extra): + def conv_PGUuid(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'UUID') field_args['validators'].append(validators.UUID()) return wtforms_fields.StringField(**field_args) @converts('MANYTOONE') - def conv_ManyToOne(self, field_args, **extra): + def conv_ManyToOne(self, column, field_args, **extra): return QuerySelectField(**field_args) @converts('MANYTOMANY', 'ONETOMANY') - def conv_ManyToMany(self, field_args, **extra): + def conv_ManyToMany(self, column, field_args, **extra): return QuerySelectMultipleField(**field_args) From 30f8977d3b755c9c7f5dfab5ed873762e40bc355 Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Fri, 24 May 2019 17:03:00 +1000 Subject: [PATCH 5/5] Always return model_fields fields in the order requested If specific fields are requested using the 'only' argument, the fields are now returned in that order. This makes use of ordered dictionary keys implemented in Python 3.7. Older versions of Python will continue to scramble the dictionary order as always. --- wtforms_sqlalchemy/orm.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 7db9c4a..d3d5593 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -255,7 +255,7 @@ def model_fields(model, db_session=None, only=None, exclude=None, mapper = model._sa_class_manager.mapper converter = converter or ModelConverter() field_args = field_args or {} - properties = [] + properties = {} for prop in mapper.iterate_properties: if getattr(prop, 'columns', None): @@ -264,16 +264,17 @@ def model_fields(model, db_session=None, only=None, exclude=None, elif exclude_pk and prop.columns[0].primary_key: continue - properties.append((prop.key, prop)) + properties[prop.key] = prop - #((p.key, p) for p in mapper.iterate_properties) if only: - properties = (x for x in properties if x[0] in only) + order = list(only) + properties = {key: properties[key] for key in only} elif exclude: - properties = (x for x in properties if x[0] not in exclude) + properties = {key: prop for key, prop in properties if key not in exclude} + order = list(properties.keys()) field_dict = {} - for name, prop in properties: + for name, prop in properties.items(): field = converter.convert( model, mapper, prop, field_args.get(name), db_session @@ -281,7 +282,7 @@ def model_fields(model, db_session=None, only=None, exclude=None, if field is not None: field_dict[name] = field - return field_dict + return {key: field_dict[key] for key in order} def model_form(model, db_session=None, base_class=Form, only=None,