diff --git a/docs/configuration.rst b/docs/configuration.rst index 9efafef..e6c13ee 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -272,6 +272,8 @@ Let's say we want to convert all unicode typed properties to TextAreaFields inst type_map = ClassMap({sa.Unicode: TextAreaField}) +In case the type_map dictionary values are not inherited from WTForm field class, they are considered callable functions. These functions will be called with the corresponding column as their only parameter. + .. _custom_base: diff --git a/tests/test_types.py b/tests/test_types.py index 03f5950..981d5eb 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -275,3 +275,24 @@ class Meta: form = ModelTestForm() assert isinstance(form.test_column, TextAreaField) + + def test_override_type_map_with_callable(self): + class ModelTest(self.base): + __tablename__ = 'model_test' + id = sa.Column(sa.Integer, primary_key=True) + test_column_short = sa.Column(sa.Unicode(255), nullable=False) + test_column_long = sa.Column(sa.Unicode(), nullable=False) + + class ModelTestForm(ModelForm): + class Meta: + model = ModelTest + not_null_validator = None + type_map = ClassMap({ + sa.Unicode: lambda column: ( + StringField if column.type.length else TextAreaField + ) + }) + + form = ModelTestForm() + assert isinstance(form.test_column_short, StringField) + assert isinstance(form.test_column_long, TextAreaField) diff --git a/wtforms_alchemy/generator.py b/wtforms_alchemy/generator.py index 5ab1202..b27ec26 100644 --- a/wtforms_alchemy/generator.py +++ b/wtforms_alchemy/generator.py @@ -9,7 +9,13 @@ import sqlalchemy as sa from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy_utils import types -from wtforms import BooleanField, FloatField, PasswordField, TextAreaField +from wtforms import ( + BooleanField, + Field, + FloatField, + PasswordField, + TextAreaField +) from wtforms.widgets import CheckboxInput, TextArea from wtforms_components import ( ColorField, @@ -607,6 +613,10 @@ def get_field_class(self, column): check_type = column.type try: - return self.TYPE_MAP[check_type] + column_type = self.TYPE_MAP[check_type] + if inspect.isclass(column_type) and issubclass(column_type, Field): + return column_type + else: + return column_type(column) except KeyError: raise UnknownTypeException(column)