diff --git a/wtforms_sqlalchemy/fields.py b/wtforms_sqlalchemy/fields.py index da5cb47..76ceced 100644 --- a/wtforms_sqlalchemy/fields.py +++ b/wtforms_sqlalchemy/fields.py @@ -57,7 +57,7 @@ class QuerySelectField(SelectFieldBase): def __init__(self, label=None, validators=None, query_factory=None, get_pk=None, get_label=None, allow_blank=False, - blank_text='', **kwargs): + blank_value='__None', blank_text='Select...', **kwargs): super(QuerySelectField, self).__init__(label, validators, **kwargs) self.query_factory = query_factory @@ -76,6 +76,7 @@ def __init__(self, label=None, validators=None, query_factory=None, self.get_label = get_label self.allow_blank = allow_blank + self.blank_value = blank_value self.blank_text = blank_text self.query = None self._object_list = None @@ -105,15 +106,15 @@ def _get_object_list(self): return self._object_list def iter_choices(self): - if self.allow_blank: - yield ('__None', self.blank_text, self.data is None) + if self.allow_blank or self.data is None: + yield (self.blank_value, self.blank_text, self.data is None) for pk, obj in self._get_object_list(): yield (pk, self.get_label(obj), obj == self.data) def process_formdata(self, valuelist): if valuelist: - if self.allow_blank and valuelist[0] == '__None': + if self.allow_blank and valuelist[0] == self.blank_value: self.data = None else: self._data = None @@ -202,3 +203,41 @@ class QueryCheckboxField(QuerySelectMultipleField): def get_pk_from_identity(obj): key = identity_key(instance=obj)[1] return ':'.join(text_type(x) for x in key) + + +class EnumSelectField(SelectFieldBase): + widget = widgets.Select() + + def __init__(self, *args, enum, members=None, **kwargs): + super().__init__(*args, **kwargs) + self.enum = enum + if members is not None: + self.members = list(members) + else: + self.members = list(enum) + self.members_set = set(self.members) + self.members_map = {member.name: member for member in self.members} + + def to_choice(self, member): + return (member.name, str(member)) + + def iter_choices(self): + for member in self.members: + yield (member.name, str(member), member is self.data) + + def process_data(self, value): + if isinstance(value, self.enum): + self.data = value + else: + self.data = None + + def process_formdata(self, valuelist): + if valuelist: + try: + self.data = self.members_map[valuelist[0]] + except (ValueError, TypeError, KeyError): + raise ValueError(self.gettext('Invalid Choice: could not coerce')) + + def pre_validate(self, form): + if self.data not in self.members: + raise ValueError(self.gettext('Not a valid choice')) diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 704abdc..d3d5593 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -7,7 +7,8 @@ from wtforms import validators, fields as wtforms_fields from wtforms.form import Form -from .fields import QuerySelectField, QuerySelectMultipleField + +from .fields import EnumSelectField, QuerySelectField, QuerySelectMultipleField __all__ = ( 'model_fields', 'model_form', @@ -82,6 +83,7 @@ def convert(self, model, mapper, prop, field_args, db_session=None): 'filters': [], 'default': None, } + kwargs['description'] = prop.doc if field_args: kwargs.update(field_args) @@ -112,11 +114,6 @@ def convert(self, model, mapper, prop, field_args, db_session=None): kwargs['default'] = default - if column.nullable: - kwargs['validators'].append(validators.Optional()) - else: - kwargs['validators'].append(validators.Required()) - converter = self.get_converter(column) else: # We have a property with a direction. @@ -150,37 +147,53 @@ class ModelConverter(ModelConverterBase): def __init__(self, extra_converters=None, use_mro=True): super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro) + @classmethod + def _nullable_required(cls, column, field_args, **extra): + if column.nullable: + field_args['validators'].append(validators.Optional()) + else: + field_args['validators'].append(validators.Required()) + @classmethod def _string_common(cls, column, field_args, **extra): if isinstance(column.type.length, int) and column.type.length: field_args['validators'].append(validators.Length(max=column.type.length)) @converts('String') # includes Unicode - def conv_String(self, field_args, **extra): - self._string_common(field_args=field_args, **extra) + def conv_String(self, column, field_args, **extra): + self._string_common(column=column, field_args=field_args, **extra) + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.StringField(**field_args) @converts('Text', 'LargeBinary', 'Binary') # includes UnicodeText - def conv_Text(self, field_args, **extra): - self._string_common(field_args=field_args, **extra) + def conv_Text(self, column, field_args, **extra): + self._string_common(column=column, field_args=field_args, **extra) + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.TextAreaField(**field_args) @converts('Boolean', 'dialects.mssql.base.BIT') - def conv_Boolean(self, field_args, **extra): + def conv_Boolean(self, column, field_args, **extra): return wtforms_fields.BooleanField(**field_args) @converts('Date') - def conv_Date(self, field_args, **extra): + def conv_Date(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.DateField(**field_args) @converts('DateTime') - def conv_DateTime(self, field_args, **extra): + def conv_DateTime(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.DateTimeField(**field_args) @converts('Enum') def conv_Enum(self, column, field_args, **extra): - field_args['choices'] = [(e, e) for e in column.type.enums] - return wtforms_fields.SelectField(**field_args) + self._nullable_required(column=column, field_args=field_args, **extra) + if column.type.enum_class is not None: + field_args['enum'] = column.type.enum_class + return EnumSelectField(**field_args) + else: + field_args['choices'] = [(e, e) for e in column.type.enums] + return wtforms_fields.SelectField(**field_args) @converts('Integer') # includes BigInteger and SmallInteger def handle_integer_types(self, column, field_args, **extra): @@ -196,34 +209,38 @@ def handle_decimal_types(self, column, field_args, **extra): return wtforms_fields.DecimalField(**field_args) @converts('dialects.mysql.types.YEAR', 'dialects.mysql.base.YEAR') - def conv_MSYear(self, field_args, **extra): + def conv_MSYear(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args['validators'].append(validators.NumberRange(min=1901, max=2155)) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.INET') - def conv_PGInet(self, field_args, **extra): + def conv_PGInet(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'IP Address') field_args['validators'].append(validators.IPAddress()) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.MACADDR') - def conv_PGMacaddr(self, field_args, **extra): + def conv_PGMacaddr(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'MAC Address') field_args['validators'].append(validators.MacAddress()) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.UUID') - def conv_PGUuid(self, field_args, **extra): + def conv_PGUuid(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'UUID') field_args['validators'].append(validators.UUID()) return wtforms_fields.StringField(**field_args) @converts('MANYTOONE') - def conv_ManyToOne(self, field_args, **extra): + def conv_ManyToOne(self, column, field_args, **extra): return QuerySelectField(**field_args) @converts('MANYTOMANY', 'ONETOMANY') - def conv_ManyToMany(self, field_args, **extra): + def conv_ManyToMany(self, column, field_args, **extra): return QuerySelectMultipleField(**field_args) @@ -238,7 +255,7 @@ def model_fields(model, db_session=None, only=None, exclude=None, mapper = model._sa_class_manager.mapper converter = converter or ModelConverter() field_args = field_args or {} - properties = [] + properties = {} for prop in mapper.iterate_properties: if getattr(prop, 'columns', None): @@ -247,16 +264,17 @@ def model_fields(model, db_session=None, only=None, exclude=None, elif exclude_pk and prop.columns[0].primary_key: continue - properties.append((prop.key, prop)) + properties[prop.key] = prop - #((p.key, p) for p in mapper.iterate_properties) if only: - properties = (x for x in properties if x[0] in only) + order = list(only) + properties = {key: properties[key] for key in only} elif exclude: - properties = (x for x in properties if x[0] not in exclude) + properties = {key: prop for key, prop in properties if key not in exclude} + order = list(properties.keys()) field_dict = {} - for name, prop in properties: + for name, prop in properties.items(): field = converter.convert( model, mapper, prop, field_args.get(name), db_session @@ -264,7 +282,7 @@ def model_fields(model, db_session=None, only=None, exclude=None, if field is not None: field_dict[name] = field - return field_dict + return {key: field_dict[key] for key in order} def model_form(model, db_session=None, base_class=Form, only=None,