Skip to content

Commit

Permalink
Merge pull request wtforms#9 from timheap/not-nullable-required
Browse files Browse the repository at this point in the history
Do not assume not-nullable means required for Boolean columns
  • Loading branch information
mlenzen committed Feb 6, 2020
2 parents fc32be9 + 3315884 commit 2f879b2
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions wtforms_sqlalchemy/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,6 @@ def convert(self, model, mapper, prop, field_args, db_session=None):

kwargs['default'] = default

if column.nullable:
kwargs['validators'].append(validators.Optional())
else:
kwargs['validators'].append(validators.Required())

converter = self.get_converter(column)
else:
# We have a property with a direction.
Expand Down Expand Up @@ -153,80 +148,99 @@ class ModelConverter(ModelConverterBase):
def __init__(self, extra_converters=None, use_mro=True):
super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro)

@classmethod
def _nullable_required(cls, column, field_args, **extra):
if column.nullable:
field_args['validators'].append(validators.Optional())
else:
field_args['validators'].append(validators.Required())

@classmethod
def _string_common(cls, column, field_args, **extra):
if isinstance(column.type.length, int) and column.type.length:
field_args['validators'].append(validators.Length(max=column.type.length))

@converts('String') # includes Unicode
def conv_String(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
def conv_String(self, column, field_args, **extra):
self._string_common(column=column, field_args=field_args, **extra)
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.StringField(**field_args)

@converts('Text', 'LargeBinary', 'Binary') # includes UnicodeText
def conv_Text(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
def conv_Text(self, column, field_args, **extra):
self._string_common(column=column, field_args=field_args, **extra)
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.TextAreaField(**field_args)

@converts('Boolean', 'dialects.mssql.base.BIT')
def conv_Boolean(self, field_args, **extra):
def conv_Boolean(self, column, field_args, **extra):
return wtforms_fields.BooleanField(**field_args)

@converts('Date')
def conv_Date(self, field_args, **extra):
def conv_Date(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.DateField(**field_args)

@converts('DateTime')
def conv_DateTime(self, field_args, **extra):
def conv_DateTime(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.DateTimeField(**field_args)

@converts('Enum')
def conv_Enum(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args['choices'] = [(e, e) for e in column.type.enums]
return wtforms_fields.SelectField(**field_args)

@converts('Integer') # includes BigInteger and SmallInteger
def handle_integer_types(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
unsigned = getattr(column.type, 'unsigned', False)
if unsigned:
field_args['validators'].append(validators.NumberRange(min=0))
return wtforms_fields.IntegerField(**field_args)

@converts('Numeric') # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE
def handle_decimal_types(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
# override default decimal places limit, use database defaults instead
field_args.setdefault('places', None)
return wtforms_fields.DecimalField(**field_args)

@converts('dialects.mysql.types.YEAR', 'dialects.mysql.base.YEAR')
def conv_MSYear(self, field_args, **extra):
def conv_MSYear(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args['validators'].append(validators.NumberRange(min=1901, max=2155))
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.INET')
def conv_PGInet(self, field_args, **extra):
def conv_PGInet(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args.setdefault('label', 'IP Address')
field_args['validators'].append(validators.IPAddress())
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.MACADDR')
def conv_PGMacaddr(self, field_args, **extra):
def conv_PGMacaddr(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args.setdefault('label', 'MAC Address')
field_args['validators'].append(validators.MacAddress())
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.UUID')
def conv_PGUuid(self, field_args, **extra):
def conv_PGUuid(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args.setdefault('label', 'UUID')
field_args['validators'].append(validators.UUID())
return wtforms_fields.StringField(**field_args)

@converts('MANYTOONE')
def conv_ManyToOne(self, field_args, **extra):
def conv_ManyToOne(self, column, field_args, **extra):
return QuerySelectField(**field_args)

@converts('MANYTOMANY', 'ONETOMANY')
def conv_ManyToMany(self, field_args, **extra):
def conv_ManyToMany(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
return QuerySelectMultipleField(**field_args)


Expand Down

0 comments on commit 2f879b2

Please sign in to comment.