Skip to content

Commit

Permalink
Improved model_form that now requires a db_session. Fixed Unique vali…
Browse files Browse the repository at this point in the history
…dator.
  • Loading branch information
jeanphix committed Oct 30, 2011
1 parent 4ae56d1 commit 623174d
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 26 deletions.
38 changes: 38 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
WTAlchemy
-----------
WTForms SQAlchemy extension.
"""
from setuptools import setup, find_packages


setup(
name='WTAlchemy',
version='0.1b',
url='https://github.com/jean-philippe/WTAlchemy',
license='mit',
author='Jean-Philippe Serafin',
author_email='serafinjp@gmail.com',
description='WTForms SQAlchemy extension',
long_description=__doc__,
data_files=[('', ['README.rst'])],
packages=find_packages(),
include_package_data=True,
zip_safe=False,
platforms='any',
install_requires=[
'WTForms==0.6.3',
],
classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Web Environment',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
'Topic :: Software Development :: Libraries :: Python Modules'
],
)
33 changes: 22 additions & 11 deletions tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ class School(Model):
class Student(Model):
__tablename__ = "student"
id = Column(Integer, primary_key=True)
full_name = Column(String(255), nullable=False)
full_name = Column(String(255), nullable=False, unique=True)
dob = Column(Date(), nullable=True)
current_school_id = Column(Integer, ForeignKey(School.id),
nullable=False)
Expand All @@ -217,46 +217,57 @@ class Student(Model):
self.School = School
self.Student = Student

engine = create_engine('sqlite:///:memory:', echo=False)
Session = sessionmaker(bind=engine)
self.metadata = Model.metadata
self.metadata.create_all(bind=engine)
self.sess = Session()

def test_nullable_field(self):
student_form = model_form(self.Student)()
student_form = model_form(self.Student, self.sess)()
self.assertTrue(issubclass(Optional,
student_form._fields['dob'].validators[0].__class__))

def test_required_field(self):
student_form = model_form(self.Student)()
student_form = model_form(self.Student, self.sess)()
self.assertTrue(issubclass(Required,
student_form._fields['full_name'].validators[0].__class__))

def test_unique_field(self):
student_form = model_form(self.Student, self.sess)()
self.assertTrue(issubclass(Unique,
student_form._fields['full_name'].validators[1].__class__))

def test_include_pk(self):
form_class = model_form(self.Student, exclude_pk=False)
form_class = model_form(self.Student, self.sess, exclude_pk=False)
student_form = form_class()
self.assertIn('id', student_form._fields)

def test_exclude_pk(self):
form_class = model_form(self.Student, exclude_pk=True)
form_class = model_form(self.Student, self.sess, exclude_pk=True)
student_form = form_class()
self.assertNotIn('id', student_form._fields)

def test_exclude_fk(self):
student_form = model_form(self.Student)()
student_form = model_form(self.Student, self.sess)()
self.assertNotIn('current_school_id', student_form._fields)

def test_include_fk(self):
student_form = model_form(self.Student, exclude_fk=False)()
student_form = model_form(self.Student, self.sess, exclude_fk=False)()
self.assertIn('current_school_id', student_form._fields)

def test_convert_many_to_one(self):
student_form = model_form(self.Student)()
student_form = model_form(self.Student, self.sess)()
self.assertTrue(issubclass(QuerySelectField,
student_form._fields['current_school'].__class__))

def test_convert_one_to_many(self):
school_form = model_form(self.School)()
school_form = model_form(self.School, self.sess)()
self.assertTrue(issubclass(QuerySelectMultipleField,
school_form._fields['students'].__class__))

def test_convert_many_to_many(self):
student_form = model_form(self.Student)()
student_form = model_form(self.Student, self.sess)()
self.assertTrue(issubclass(QuerySelectMultipleField,
student_form._fields['courses'].__class__))

Expand All @@ -282,7 +293,7 @@ class User(Model):
class UserForm(Form):
username = TextField('Username', [
Length(min=4, max=25),
Unique(lambda: self.sess, User.username)
Unique(lambda: self.sess, User, User.username)
])

self.UserForm = UserForm
Expand Down
30 changes: 23 additions & 7 deletions wtalchemy/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from wtforms.form import Form
from wtalchemy.fields import QuerySelectField
from wtalchemy.fields import QuerySelectMultipleField
from wtalchemy.validators import Unique


__all__ = (
Expand Down Expand Up @@ -37,7 +38,7 @@ def __init__(self, converters, use_mro=True):

self.converters = converters

def convert(self, model, mapper, prop, field_args):
def convert(self, model, db_session, mapper, prop, field_args):
if not isinstance(prop, sqlalchemy.orm.properties.ColumnProperty) and \
not isinstance(prop,
sqlalchemy.orm.properties.RelationshipProperty):
Expand Down Expand Up @@ -78,6 +79,10 @@ def convert(self, model, mapper, prop, field_args):
else:
kwargs['validators'].append(validators.Required())

if column.unique:
kwargs['validators'].append(Unique(lambda: db_session, model,
column))

if self.use_mro:
types = inspect.getmro(type(column.type))
else:
Expand Down Expand Up @@ -110,7 +115,7 @@ def convert(self, model, mapper, prop, field_args):

kwargs.update({
'allow_blank': nullable,
'query_factory': lambda: foreign_model.query.all()
'query_factory': lambda: db_session.query(foreign_model).all()
})

converter = self.converters[prop.direction.name]
Expand Down Expand Up @@ -189,7 +194,7 @@ def conv_ManyToOne(self, field_args, **extra):
return QuerySelectField(**field_args)


def model_fields(model, only=None, exclude=None, field_args=None,
def model_fields(model, db_session, only=None, exclude=None, field_args=None,
converter=None):
"""
Generate a dictionary of fields for a given SQLAlchemy model.
Expand All @@ -211,14 +216,15 @@ def model_fields(model, only=None, exclude=None, field_args=None,

field_dict = {}
for name, prop in properties:
field = converter.convert(model, mapper, prop, field_args.get(name))
field = converter.convert(model, db_session, mapper, prop,
field_args.get(name))
if field is not None:
field_dict[name] = field

return field_dict


def model_form(model, base_class=Form, only=None, exclude=None,
def model_form(model, db_session, base_class=Form, only=None, exclude=None,
field_args=None, converter=None, exclude_pk=True, exclude_fk=True,
type_name=None):
"""
Expand All @@ -230,6 +236,8 @@ def model_form(model, base_class=Form, only=None, exclude=None,
:param model:
A SQLAlchemy mapped model class.
:param db_session:
A SQLAlchemy Session.
:param base_class:
Base form class to extend from. Must be a ``wtforms.Form`` subclass.
:param only:
Expand All @@ -251,6 +259,13 @@ def model_form(model, base_class=Form, only=None, exclude=None,
:param type_name:
An optional string to set returned type name.
"""
class ModelForm(base_class):
"""Sets object as form attribute."""
def __init__(self, *args, **kwargs):
if 'obj' in kwargs:
self._obj = kwargs['obj']
super(ModelForm, self).__init__(*args, **kwargs)

if not exclude:
exclude = []
model_mapper = model.__mapper__
Expand All @@ -264,5 +279,6 @@ def model_form(model, base_class=Form, only=None, exclude=None,
for pair in prop.local_remote_pairs:
exclude.append(pair[0].key)
type_name = type_name or model.__name__ + 'Form'
field_dict = model_fields(model, only, exclude, field_args, converter)
return type(type_name, (base_class, ), field_dict)
field_dict = model_fields(model, db_session, only, exclude, field_args,
converter)
return type(type_name, (ModelForm, ), field_dict)
18 changes: 10 additions & 8 deletions wtalchemy/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,19 @@ class Unique(object):
"""
field_flags = ('unique', )

def __init__(self, get_db_session, field, message=None):
self.get_db_session = get_db_session
self.field = field
def __init__(self, get_session, model, column, message=None):
self.get_session = get_session
self.model = model
self.column = column
self.message = message

def __call__(self, form, field):
try:
self.get_db_session().query(self.field)\
.filter(self.field == field.data).one()
if self.message is None:
self.message = field.gettext(u'Allready exists.')
raise ValidationError(self.message)
obj = self.get_session().query(self.model)\
.filter(self.column == field.data).one()
if not hasattr(form, '_obj') or not form._obj == obj:
if self.message is None:
self.message = field.gettext(u'Allready exists.')
raise ValidationError(self.message)
except NoResultFound:
pass

0 comments on commit 623174d

Please sign in to comment.