Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 43 additions & 4 deletions wtforms_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class QuerySelectField(SelectFieldBase):

def __init__(self, label=None, validators=None, query_factory=None,
get_pk=None, get_label=None, allow_blank=False,
blank_text='', **kwargs):
blank_value='__None', blank_text='Select...', **kwargs):
super(QuerySelectField, self).__init__(label, validators, **kwargs)
self.query_factory = query_factory

Expand All @@ -76,6 +76,7 @@ def __init__(self, label=None, validators=None, query_factory=None,
self.get_label = get_label

self.allow_blank = allow_blank
self.blank_value = blank_value
self.blank_text = blank_text
self.query = None
self._object_list = None
Expand Down Expand Up @@ -105,15 +106,15 @@ def _get_object_list(self):
return self._object_list

def iter_choices(self):
if self.allow_blank:
yield ('__None', self.blank_text, self.data is None)
if self.allow_blank or self.data is None:
yield (self.blank_value, self.blank_text, self.data is None)

for pk, obj in self._get_object_list():
yield (pk, self.get_label(obj), obj == self.data)

def process_formdata(self, valuelist):
if valuelist:
if self.allow_blank and valuelist[0] == '__None':
if self.allow_blank and valuelist[0] == self.blank_value:
self.data = None
else:
self._data = None
Expand Down Expand Up @@ -202,3 +203,41 @@ class QueryCheckboxField(QuerySelectMultipleField):
def get_pk_from_identity(obj):
key = identity_key(instance=obj)[1]
return ':'.join(text_type(x) for x in key)


class EnumSelectField(SelectFieldBase):
widget = widgets.Select()

def __init__(self, *args, enum, members=None, **kwargs):
super().__init__(*args, **kwargs)
self.enum = enum
if members is not None:
self.members = list(members)
else:
self.members = list(enum)
self.members_set = set(self.members)
self.members_map = {member.name: member for member in self.members}

def to_choice(self, member):
return (member.name, str(member))

def iter_choices(self):
for member in self.members:
yield (member.name, str(member), member is self.data)

def process_data(self, value):
if isinstance(value, self.enum):
self.data = value
else:
self.data = None

def process_formdata(self, valuelist):
if valuelist:
try:
self.data = self.members_map[valuelist[0]]
except (ValueError, TypeError, KeyError):
raise ValueError(self.gettext('Invalid Choice: could not coerce'))

def pre_validate(self, form):
if self.data not in self.members:
raise ValueError(self.gettext('Not a valid choice'))
74 changes: 46 additions & 28 deletions wtforms_sqlalchemy/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from wtforms import validators, fields as wtforms_fields
from wtforms.form import Form
from .fields import QuerySelectField, QuerySelectMultipleField

from .fields import EnumSelectField, QuerySelectField, QuerySelectMultipleField

__all__ = (
'model_fields', 'model_form',
Expand Down Expand Up @@ -82,6 +83,7 @@ def convert(self, model, mapper, prop, field_args, db_session=None):
'filters': [],
'default': None,
}
kwargs['description'] = prop.doc

if field_args:
kwargs.update(field_args)
Expand Down Expand Up @@ -112,11 +114,6 @@ def convert(self, model, mapper, prop, field_args, db_session=None):

kwargs['default'] = default

if column.nullable:
kwargs['validators'].append(validators.Optional())
else:
kwargs['validators'].append(validators.Required())

converter = self.get_converter(column)
else:
# We have a property with a direction.
Expand Down Expand Up @@ -150,37 +147,53 @@ class ModelConverter(ModelConverterBase):
def __init__(self, extra_converters=None, use_mro=True):
super(ModelConverter, self).__init__(extra_converters, use_mro=use_mro)

@classmethod
def _nullable_required(cls, column, field_args, **extra):
if column.nullable:
field_args['validators'].append(validators.Optional())
else:
field_args['validators'].append(validators.Required())

@classmethod
def _string_common(cls, column, field_args, **extra):
if isinstance(column.type.length, int) and column.type.length:
field_args['validators'].append(validators.Length(max=column.type.length))

@converts('String') # includes Unicode
def conv_String(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
def conv_String(self, column, field_args, **extra):
self._string_common(column=column, field_args=field_args, **extra)
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.StringField(**field_args)

@converts('Text', 'LargeBinary', 'Binary') # includes UnicodeText
def conv_Text(self, field_args, **extra):
self._string_common(field_args=field_args, **extra)
def conv_Text(self, column, field_args, **extra):
self._string_common(column=column, field_args=field_args, **extra)
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.TextAreaField(**field_args)

@converts('Boolean', 'dialects.mssql.base.BIT')
def conv_Boolean(self, field_args, **extra):
def conv_Boolean(self, column, field_args, **extra):
return wtforms_fields.BooleanField(**field_args)

@converts('Date')
def conv_Date(self, field_args, **extra):
def conv_Date(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.DateField(**field_args)

@converts('DateTime')
def conv_DateTime(self, field_args, **extra):
def conv_DateTime(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
return wtforms_fields.DateTimeField(**field_args)

@converts('Enum')
def conv_Enum(self, column, field_args, **extra):
field_args['choices'] = [(e, e) for e in column.type.enums]
return wtforms_fields.SelectField(**field_args)
self._nullable_required(column=column, field_args=field_args, **extra)
if column.type.enum_class is not None:
field_args['enum'] = column.type.enum_class
return EnumSelectField(**field_args)
else:
field_args['choices'] = [(e, e) for e in column.type.enums]
return wtforms_fields.SelectField(**field_args)

@converts('Integer') # includes BigInteger and SmallInteger
def handle_integer_types(self, column, field_args, **extra):
Expand All @@ -196,34 +209,38 @@ def handle_decimal_types(self, column, field_args, **extra):
return wtforms_fields.DecimalField(**field_args)

@converts('dialects.mysql.types.YEAR', 'dialects.mysql.base.YEAR')
def conv_MSYear(self, field_args, **extra):
def conv_MSYear(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args['validators'].append(validators.NumberRange(min=1901, max=2155))
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.INET')
def conv_PGInet(self, field_args, **extra):
def conv_PGInet(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args.setdefault('label', 'IP Address')
field_args['validators'].append(validators.IPAddress())
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.MACADDR')
def conv_PGMacaddr(self, field_args, **extra):
def conv_PGMacaddr(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args.setdefault('label', 'MAC Address')
field_args['validators'].append(validators.MacAddress())
return wtforms_fields.StringField(**field_args)

@converts('dialects.postgresql.base.UUID')
def conv_PGUuid(self, field_args, **extra):
def conv_PGUuid(self, column, field_args, **extra):
self._nullable_required(column=column, field_args=field_args, **extra)
field_args.setdefault('label', 'UUID')
field_args['validators'].append(validators.UUID())
return wtforms_fields.StringField(**field_args)

@converts('MANYTOONE')
def conv_ManyToOne(self, field_args, **extra):
def conv_ManyToOne(self, column, field_args, **extra):
return QuerySelectField(**field_args)

@converts('MANYTOMANY', 'ONETOMANY')
def conv_ManyToMany(self, field_args, **extra):
def conv_ManyToMany(self, column, field_args, **extra):
return QuerySelectMultipleField(**field_args)


Expand All @@ -238,7 +255,7 @@ def model_fields(model, db_session=None, only=None, exclude=None,
mapper = model._sa_class_manager.mapper
converter = converter or ModelConverter()
field_args = field_args or {}
properties = []
properties = {}

for prop in mapper.iterate_properties:
if getattr(prop, 'columns', None):
Expand All @@ -247,24 +264,25 @@ def model_fields(model, db_session=None, only=None, exclude=None,
elif exclude_pk and prop.columns[0].primary_key:
continue

properties.append((prop.key, prop))
properties[prop.key] = prop

#((p.key, p) for p in mapper.iterate_properties)
if only:
properties = (x for x in properties if x[0] in only)
order = list(only)
properties = {key: properties[key] for key in only}
elif exclude:
properties = (x for x in properties if x[0] not in exclude)
properties = {key: prop for key, prop in properties if key not in exclude}
order = list(properties.keys())

field_dict = {}
for name, prop in properties:
for name, prop in properties.items():
field = converter.convert(
model, mapper, prop,
field_args.get(name), db_session
)
if field is not None:
field_dict[name] = field

return field_dict
return {key: field_dict[key] for key in order}


def model_form(model, db_session=None, base_class=Form, only=None,
Expand Down