diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index f1fefb6..5a68fd3 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -115,11 +115,6 @@ def convert(self, model, mapper, prop, field_args, db_session=None): kwargs['default'] = default - if column.nullable: - kwargs['validators'].append(validators.Optional()) - else: - kwargs['validators'].append(validators.Required()) - converter = self.get_converter(column) else: # We have a property with a direction. @@ -153,40 +148,53 @@ class ModelConverter(ModelConverterBase): def __init__(self, extra_converters=None, use_mro=True): super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro) + @classmethod + def _nullable_required(cls, column, field_args, **extra): + if column.nullable: + field_args['validators'].append(validators.Optional()) + else: + field_args['validators'].append(validators.Required()) + @classmethod def _string_common(cls, column, field_args, **extra): if isinstance(column.type.length, int) and column.type.length: field_args['validators'].append(validators.Length(max=column.type.length)) @converts('String') # includes Unicode - def conv_String(self, field_args, **extra): - self._string_common(field_args=field_args, **extra) + def conv_String(self, column, field_args, **extra): + self._string_common(column=column, field_args=field_args, **extra) + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.StringField(**field_args) @converts('Text', 'LargeBinary', 'Binary') # includes UnicodeText - def conv_Text(self, field_args, **extra): - self._string_common(field_args=field_args, **extra) + def conv_Text(self, column, field_args, **extra): + self._string_common(column=column, field_args=field_args, **extra) + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.TextAreaField(**field_args) @converts('Boolean', 'dialects.mssql.base.BIT') - def conv_Boolean(self, field_args, **extra): + def conv_Boolean(self, column, field_args, **extra): return wtforms_fields.BooleanField(**field_args) @converts('Date') - def conv_Date(self, field_args, **extra): + def conv_Date(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.DateField(**field_args) @converts('DateTime') - def conv_DateTime(self, field_args, **extra): + def conv_DateTime(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) return wtforms_fields.DateTimeField(**field_args) @converts('Enum') def conv_Enum(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args['choices'] = [(e, e) for e in column.type.enums] return wtforms_fields.SelectField(**field_args) @converts('Integer') # includes BigInteger and SmallInteger def handle_integer_types(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) unsigned = getattr(column.type, 'unsigned', False) if unsigned: field_args['validators'].append(validators.NumberRange(min=0)) @@ -194,39 +202,45 @@ def handle_integer_types(self, column, field_args, **extra): @converts('Numeric') # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE def handle_decimal_types(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) # override default decimal places limit, use database defaults instead field_args.setdefault('places', None) return wtforms_fields.DecimalField(**field_args) @converts('dialects.mysql.types.YEAR', 'dialects.mysql.base.YEAR') - def conv_MSYear(self, field_args, **extra): + def conv_MSYear(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args['validators'].append(validators.NumberRange(min=1901, max=2155)) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.INET') - def conv_PGInet(self, field_args, **extra): + def conv_PGInet(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'IP Address') field_args['validators'].append(validators.IPAddress()) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.MACADDR') - def conv_PGMacaddr(self, field_args, **extra): + def conv_PGMacaddr(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'MAC Address') field_args['validators'].append(validators.MacAddress()) return wtforms_fields.StringField(**field_args) @converts('dialects.postgresql.base.UUID') - def conv_PGUuid(self, field_args, **extra): + def conv_PGUuid(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) field_args.setdefault('label', 'UUID') field_args['validators'].append(validators.UUID()) return wtforms_fields.StringField(**field_args) @converts('MANYTOONE') - def conv_ManyToOne(self, field_args, **extra): + def conv_ManyToOne(self, column, field_args, **extra): return QuerySelectField(**field_args) @converts('MANYTOMANY', 'ONETOMANY') - def conv_ManyToMany(self, field_args, **extra): + def conv_ManyToMany(self, column, field_args, **extra): + self._nullable_required(column=column, field_args=field_args, **extra) return QuerySelectMultipleField(**field_args)