Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class F(Form):
.filter(self.Test.id == 1, self.Test.id != 1)
.all()
)
self.assertEqual(form.a(), [])
self.assertEqual(form.a(), [('__None', 'Select...', True)])


def test_with_query_factory(self):
Expand All @@ -118,16 +118,16 @@ class F(Form):

form = F()
self.assertEqual(form.a.data, None)
self.assertEqual(form.a(), [('1', 'apple', False), ('2', 'banana', False)])
self.assertEqual(form.a(), [('__None', 'Select...', True), ('1', 'apple', False), ('2', 'banana', False)])
self.assertEqual(form.b.data, None)
self.assertEqual(form.b(), [('__None', '', True), ('hello1', 'apple', False), ('hello2', 'banana', False)])
self.assertEqual(form.b(), [('__None', 'Select...', True), ('hello1', 'apple', False), ('hello2', 'banana', False)])
self.assertFalse(form.validate())

form = F(DummyPostData(a=['1'], b=['hello2']))
self.assertEqual(form.a.data.id, 1)
self.assertEqual(form.a(), [('1', 'apple', True), ('2', 'banana', False)])
self.assertEqual(form.b.data.baz, 'banana')
self.assertEqual(form.b(), [('__None', '', False), ('hello1', 'apple', False), ('hello2', 'banana', True)])
self.assertEqual(form.b(), [('__None', 'Select...', False), ('hello1', 'apple', False), ('hello2', 'banana', True)])
self.assertTrue(form.validate())

# Make sure the query is cached
Expand All @@ -152,7 +152,7 @@ class F(Form):
.filter(self.Test.id == 1, self.Test.id != 1)
.all()
)
self.assertEqual(form.a(), [])
self.assertEqual(form.a(), [('__None', 'Select...', True)])


class QuerySelectMultipleFieldTest(TestBase):
Expand Down
15 changes: 9 additions & 6 deletions wtforms_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import operator

from wtforms import widgets
from wtforms.compat import text_type, string_types
from wtforms.compat import string_types, text_type
from wtforms.fields import SelectFieldBase
from wtforms.validators import ValidationError

Expand Down Expand Up @@ -57,7 +57,7 @@ class QuerySelectField(SelectFieldBase):

def __init__(self, label=None, validators=None, query_factory=None,
get_pk=None, get_label=None, allow_blank=False,
blank_text='', **kwargs):
blank_value='__None', blank_text='Select...', **kwargs):
super(QuerySelectField, self).__init__(label, validators, **kwargs)
self.query_factory = query_factory

Expand All @@ -76,6 +76,7 @@ def __init__(self, label=None, validators=None, query_factory=None,
self.get_label = get_label

self.allow_blank = allow_blank
self.blank_value = blank_value
self.blank_text = blank_text
self.query = None
self._object_list = None
Expand Down Expand Up @@ -105,15 +106,15 @@ def _get_object_list(self):
return self._object_list

def iter_choices(self):
if self.allow_blank:
yield ('__None', self.blank_text, self.data is None)
if self.allow_blank or self.data is None:
yield (self.blank_value, self.gettext(self.blank_text), self.data is None)

for pk, obj in self._get_object_list():
yield (pk, self.get_label(obj), obj == self.data)

def process_formdata(self, valuelist):
if valuelist:
if self.allow_blank and valuelist[0] == '__None':
if self.allow_blank and valuelist[0] == self.blank_value:
self.data = None
else:
self._data = None
Expand All @@ -127,8 +128,10 @@ def pre_validate(self, form):
break
else:
raise ValidationError(self.gettext('Not a valid choice'))
elif self._formdata or not self.allow_blank:
elif self._formdata is not None:
raise ValidationError(self.gettext('Not a valid choice'))
elif not self.allow_blank:
raise ValidationError(self.gettext('This field is required'))


class QuerySelectMultipleField(QuerySelectField):
Expand Down