From a57c4312e9000f15619299e91018993e4bf8d39c Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Fri, 24 May 2019 16:57:56 +1000 Subject: [PATCH 1/2] Use specific EnumSelectField for PEP-435 enum columns --- wtforms_sqlalchemy/fields.py | 38 ++++++++++++++++++++++++++++++++++++ wtforms_sqlalchemy/orm.py | 11 ++++++++--- 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/wtforms_sqlalchemy/fields.py b/wtforms_sqlalchemy/fields.py index da5cb47..ab2bbb9 100644 --- a/wtforms_sqlalchemy/fields.py +++ b/wtforms_sqlalchemy/fields.py @@ -202,3 +202,41 @@ class QueryCheckboxField(QuerySelectMultipleField): def get_pk_from_identity(obj): key = identity_key(instance=obj)[1] return ':'.join(text_type(x) for x in key) + + +class EnumSelectField(SelectFieldBase): + widget = widgets.Select() + + def __init__(self, *args, enum, members=None, **kwargs): + super().__init__(*args, **kwargs) + self.enum = enum + if members is not None: + self.members = list(members) + else: + self.members = list(enum) + self.members_set = set(self.members) + self.members_map = {member.name: member for member in self.members} + + def to_choice(self, member): + return (member.name, str(member)) + + def iter_choices(self): + for member in self.members: + yield (member.name, str(member), member is self.data) + + def process_data(self, value): + if isinstance(value, self.enum): + self.data = value + else: + self.data = None + + def process_formdata(self, valuelist): + if valuelist: + try: + self.data = self.members_map[valuelist[0]] + except (ValueError, TypeError, KeyError): + raise ValueError(self.gettext('Invalid Choice: could not coerce')) + + def pre_validate(self, form): + if self.data not in self.members: + raise ValueError(self.gettext('Not a valid choice')) diff --git a/wtforms_sqlalchemy/orm.py b/wtforms_sqlalchemy/orm.py index 704abdc..0e2f63d 100644 --- a/wtforms_sqlalchemy/orm.py +++ b/wtforms_sqlalchemy/orm.py @@ -7,7 +7,8 @@ from wtforms import validators, fields as wtforms_fields from wtforms.form import Form -from .fields import QuerySelectField, QuerySelectMultipleField + +from .fields import EnumSelectField, QuerySelectField, QuerySelectMultipleField __all__ = ( 'model_fields', 'model_form', @@ -179,8 +180,12 @@ def conv_DateTime(self, field_args, **extra): @converts('Enum') def conv_Enum(self, column, field_args, **extra): - field_args['choices'] = [(e, e) for e in column.type.enums] - return wtforms_fields.SelectField(**field_args) + if column.type.enum_class is not None: + field_args['enum'] = column.type.enum_class + return EnumSelectField(**field_args) + else: + field_args['choices'] = [(e, e) for e in column.type.enums] + return wtforms_fields.SelectField(**field_args) @converts('Integer') # includes BigInteger and SmallInteger def handle_integer_types(self, column, field_args, **extra): From 34d42ce7fadef5fd4421ccefb7dbc08bc127d37d Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Wed, 29 Jan 2020 11:45:47 +1100 Subject: [PATCH 2/2] Only test EnumSelectField on compatible versions --- tests/tests.py | 41 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/tests/tests.py b/tests/tests.py index 7b5bce7..6cffc2d 100755 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,6 +1,8 @@ from __future__ import unicode_literals, absolute_import -from sqlalchemy import create_engine, ForeignKey, types as sqla_types +import sys + +from sqlalchemy import create_engine, ForeignKey, types as sqla_types, __version__ as _sqla_version from sqlalchemy.schema import MetaData, Table, Column, ColumnDefault from sqlalchemy.orm import sessionmaker, relationship, backref from sqlalchemy.ext.declarative import declarative_base @@ -8,15 +10,24 @@ from sqlalchemy.dialects.mysql import YEAR from sqlalchemy.dialects.mssql import BIT -from unittest import TestCase +from unittest import TestCase, skipIf from wtforms.compat import text_type, iteritems -from wtforms_sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField +from wtforms_sqlalchemy.fields import QuerySelectField, QuerySelectMultipleField, EnumSelectField from wtforms import Form, fields from wtforms_sqlalchemy.orm import model_form, ModelConversionError, ModelConverter from wtforms.validators import Optional, Required, Regexp from .common import DummyPostData, contains_validator +try: + import enum +except ImportError: + pass + + +sqla_version = tuple(int(i) for i in _sqla_version.split('.')) + + class LazySelect(object): def __call__(self, field, **kwargs): @@ -423,3 +434,27 @@ def test_convert_types(self): assert isinstance(form.timestamp, fields.DateTimeField) assert isinstance(form.date, fields.DateField) + +@skipIf( + sqla_version < (1, 1) or sys.version_info < (3, 4), + "PEP-435-style enum class support was added in SQLAlchemy 1.1, Python 3.4") +class ModelFormEnumTest(TestCase): + def setUp(self): + Model = declarative_base() + class Fruit(enum.Enum): + orange = 1 + banana = 2 + apple = 3 + + class EnumModel(Model): + __tablename__ = "course" + id = Column(sqla_types.Integer, primary_key=True) + favourite_fruit = Column(sqla_types.Enum(Fruit)) + + self.EnumModel = EnumModel + + + def test_enum_type(self): + form = model_form(self.EnumModel)() + + assert isinstance(form.favourite_fruit, EnumSelectField)