Skip to content

Commit

Permalink
add a form ident field to the keg-elements base form
Browse files Browse the repository at this point in the history
- hidden field to identify a form in the request
- field key and value have defaults, but may be customized
- excluded from custom ordering (like csrf)
refs #3
  • Loading branch information
guruofgentoo committed Dec 2, 2022
1 parent 60fac30 commit 941e4a7
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 7 deletions.
65 changes: 59 additions & 6 deletions keg_elements/forms/__init__.py
Expand Up @@ -2,10 +2,11 @@
import inspect
import logging
import warnings
from decimal import Decimal
from operator import attrgetter

from blazeutils.strings import case_cw2dash
import flask
from decimal import Decimal
from flask_wtf import FlaskForm as BaseForm
from keg.db import db
import sqlalchemy as sa
Expand All @@ -14,7 +15,7 @@
import six
import wtforms.fields
import wtforms.form
from wtforms.validators import InputRequired, Optional, StopValidation, NumberRange
from wtforms.validators import InputRequired, Optional, StopValidation, NumberRange, AnyOf
from wtforms_alchemy import (
FormGenerator as FormGeneratorBase,
model_form_factory,
Expand Down Expand Up @@ -796,20 +797,31 @@ class MyForm(Form):
field1 = String('field1_label') # Note that we don't use the label in the ordering
field2 = String()
"""
_form_ident_enabled = True
_form_ident_strict = True

def __init__(self, *args, **kwargs):
super(Form, self).__init__(*args, **kwargs)
self._form_level_errors = []
self._errors = None
self.after_init(args, kwargs)

def __init_subclass__(cls):
cls.add_form_ident()
super().__init_subclass__()

def __iter__(self):
order = getattr(self, '_field_order', None)
custom_field_order = getattr(self, '_field_order', None)

if order is None:
if custom_field_order is None:
return super().__iter__()

has_csrf = hasattr(self, 'csrf_token')
order = (['csrf_token'] if has_csrf else []) + list(order)
order = []
if hasattr(self, 'csrf_token'):
order.append('csrf_token')
if self._form_ident_enabled:
order.append(self._form_ident_key())
order.extend(list(custom_field_order))

declared = set(self._fields.keys())
ordered = set(order)
Expand Down Expand Up @@ -871,6 +883,47 @@ def field_errors(self):
)
return self.errors

@classmethod
def add_form_ident(cls):
# may need to clean up from a superclass init, so we have fresh config here
key = cls._form_ident_key()
if hasattr(cls, key):
setattr(cls, key, None)

if not cls._form_ident_enabled:
return

if key.startswith('_'):
raise Exception('Cannot start form ident name with "_", since WTForms will ignore')

validators = []
value = cls._form_ident_value()
if cls._form_ident_strict:
validators.append(AnyOf([value]))

setattr(
cls,
key,
wtforms.fields.HiddenField(
default=value,
validators=validators,
)
)

@classmethod
def _form_ident_key(cls):
"""Field name to embed as a hidden value for form identification. Default is keg_form_ident.
Note: this cannot start with an underscore, or WTForms will ignore the field.
"""
return 'keg_form_ident'

@classmethod
def _form_ident_value(cls):
"""Field value to embed for form identification. Default is class name converted to
dash notation."""
return case_cw2dash(cls.__name__)


BaseModelFormMeta = model_form_meta_factory()

Expand Down
74 changes: 73 additions & 1 deletion keg_elements/tests/test_forms/test_form.py
Expand Up @@ -345,6 +345,7 @@ def test_length_validation_not_applied_for_enums(self):


class FeaturesForm(Form):
_form_ident_enabled = False
name = wtf.StringField(validators=[validators.data_required()])
color = wtf.StringField()

Expand All @@ -355,6 +356,7 @@ class NumbersSubForm(wtf.Form):


class NumbersForm(Form):
_form_ident_enabled = False
numbers = wtf.FieldList(wtf.FormField(NumbersSubForm), min_entries=2)
numbers2 = wtf.FieldList(wtf.StringField('Number'), min_entries=2)

Expand Down Expand Up @@ -561,6 +563,7 @@ def generate_csrf_token(self, token):
return 'token'

class CSRF(Form):
_form_ident_enabled = False
_field_order = ('num2', 'num1',)

class Meta:
Expand All @@ -583,6 +586,18 @@ class OrderedForm(Form):

form = OrderedForm()

assert [x.name for x in form] == ['keg_form_ident', 'num3', 'num1', 'num2']

def test_field_order_no_ident(self):
class OrderedForm(Form):
_form_ident_enabled = False
_field_order = ('num3', 'num1', 'num2',)
num1 = wtf.IntegerField()
num2 = wtf.IntegerField()
num3 = wtf.IntegerField()

form = OrderedForm()

assert [x.name for x in form] == ['num3', 'num1', 'num2']

def test_field_unorder(self):
Expand All @@ -593,7 +608,64 @@ class UnorderedForm(Form):

form = UnorderedForm()

assert [x.name for x in form] == ['num1', 'num2', 'num3']
assert [x.name for x in form] == ['num1', 'num2', 'num3', 'keg_form_ident']


class TestFormIdentValidation(FormBase):
class MyForm(Form):
num1 = wtf.IntegerField()

form_cls = MyForm

def test_form_ident_validated(self):
form = self.assert_invalid(keg_form_ident='foo')
assert form.form_errors == []
assert form.errors == {'keg_form_ident': ['Invalid value, must be one of: my-form.']}
self.assert_valid(keg_form_ident='my-form')

def test_form_ident_validated_custom_key(self):
class TestForm(self.MyForm):
@classmethod
def _form_ident_key(cls):
return 'mycoolfield'

form = self.assert_invalid(form_cls=TestForm, mycoolfield='foo')
assert form.form_errors == []
assert form.errors == {'mycoolfield': ['Invalid value, must be one of: test-form.']}
self.assert_valid(form_cls=TestForm, mycoolfield='test-form')

def test_form_ident_validated_custom_key_invalid(self):
with pytest.raises(Exception, match='Cannot start form ident name with "_"'):
class TestForm(self.MyForm):
@classmethod
def _form_ident_key(cls):
return '_mycoolfield'

def test_form_ident_validated_custom_value(self):
class TestForm(self.MyForm):
@classmethod
def _form_ident_value(cls):
return 'bar'

form = self.assert_invalid(form_cls=TestForm, keg_form_ident='foo')
assert form.form_errors == []
assert form.errors == {'keg_form_ident': ['Invalid value, must be one of: bar.']}
self.assert_valid(form_cls=TestForm, keg_form_ident='bar')

def test_form_ident_not_strict(self):
class TestForm(self.MyForm):
_form_ident_strict = False

self.assert_valid(form_cls=TestForm, keg_form_ident='foo')

def test_form_ident_not_exists(self):
class TestForm(self.MyForm):
_form_ident_enabled = False

form = TestForm()
assert not getattr(TestForm, 'keg_form_ident', None)
assert not getattr(form, 'keg_form_ident', None)
assert 'keg_form_ident' not in form._fields


class TestFormLevelValidation(FormBase):
Expand Down

0 comments on commit 941e4a7

Please sign in to comment.