Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions tests/tests.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,33 @@
from __future__ import unicode_literals, absolute_import

from sqlalchemy import create_engine, ForeignKey, types as sqla_types
import sys

from sqlalchemy import create_engine, ForeignKey, types as sqla_types, __version__ as _sqla_version
from sqlalchemy.schema import MetaData, Table, Column, ColumnDefault
from sqlalchemy.orm import sessionmaker, relationship, backref
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.dialects.postgresql import INET, MACADDR, UUID
from sqlalchemy.dialects.mysql import YEAR
from sqlalchemy.dialects.mssql import BIT

from unittest import TestCase
from unittest import TestCase, skipIf

from wtforms.compat import text_type, iteritems
from wtforms_sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField
from wtforms_sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField, EnumSelectField
from wtforms import Form, fields
from wtforms_sqlalchemy.orm import model_form, ModelConversionError, ModelConverter
from wtforms.validators import Optional, Required, Regexp
from .common import DummyPostData, contains_validator

try:
import enum
except ImportError:
pass


sqla_version = tuple(int(i) for i in _sqla_version.split('.'))



class LazySelect(object):
def __call__(self, field, **kwargs):
Expand Down Expand Up @@ -423,3 +434,27 @@ def test_convert_types(self):
assert isinstance(form.timestamp, fields.DateTimeField)

assert isinstance(form.date, fields.DateField)

@skipIf(
sqla_version < (1, 1) or sys.version_info < (3, 4),
"PEP-435-style enum class support was added in SQLAlchemy 1.1, Python 3.4")
class ModelFormEnumTest(TestCase):
def setUp(self):
Model = declarative_base()
class Fruit(enum.Enum):
orange = 1
banana = 2
apple = 3

class EnumModel(Model):
__tablename__ = "course"
id = Column(sqla_types.Integer, primary_key=True)
favourite_fruit = Column(sqla_types.Enum(Fruit))

self.EnumModel = EnumModel


def test_enum_type(self):
form = model_form(self.EnumModel)()

assert isinstance(form.favourite_fruit, EnumSelectField)
38 changes: 38 additions & 0 deletions wtforms_sqlalchemy/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,41 @@ class QueryCheckboxField(QuerySelectMultipleField):
def get_pk_from_identity(obj):
key = identity_key(instance=obj)[1]
return ':'.join(text_type(x) for x in key)


class EnumSelectField(SelectFieldBase):
widget = widgets.Select()

def __init__(self, *args, enum, members=None, **kwargs):
super().__init__(*args, **kwargs)
self.enum = enum
if members is not None:
self.members = list(members)
else:
self.members = list(enum)
self.members_set = set(self.members)
self.members_map = {member.name: member for member in self.members}

def to_choice(self, member):
return (member.name, str(member))

def iter_choices(self):
for member in self.members:
yield (member.name, str(member), member is self.data)

def process_data(self, value):
if isinstance(value, self.enum):
self.data = value
else:
self.data = None

def process_formdata(self, valuelist):
if valuelist:
try:
self.data = self.members_map[valuelist[0]]
except (ValueError, TypeError, KeyError):
raise ValueError(self.gettext('Invalid Choice: could not coerce'))

def pre_validate(self, form):
if self.data not in self.members:
raise ValueError(self.gettext('Not a valid choice'))
11 changes: 8 additions & 3 deletions wtforms_sqlalchemy/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from wtforms import validators, fields as wtforms_fields
from wtforms.form import Form
from .fields import QuerySelectField, QuerySelectMultipleField

from .fields import EnumSelectField, QuerySelectField, QuerySelectMultipleField

__all__ = (
'model_fields', 'model_form',
Expand Down Expand Up @@ -179,8 +180,12 @@ def conv_DateTime(self, field_args, **extra):

@converts('Enum')
def conv_Enum(self, column, field_args, **extra):
field_args['choices'] = [(e, e) for e in column.type.enums]
return wtforms_fields.SelectField(**field_args)
if column.type.enum_class is not None:
field_args['enum'] = column.type.enum_class
return EnumSelectField(**field_args)
else:
field_args['choices'] = [(e, e) for e in column.type.enums]
return wtforms_fields.SelectField(**field_args)

@converts('Integer') # includes BigInteger and SmallInteger
def handle_integer_types(self, column, field_args, **extra):
Expand Down