Skip to content

Commit

Permalink
handle SA ORM attributes and hybrid properties for relationship form …
Browse files Browse the repository at this point in the history
…fields

refs #174
  • Loading branch information
guruofgentoo committed Dec 2, 2022
1 parent 9ecb616 commit 97244de
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 11 deletions.
38 changes: 28 additions & 10 deletions keg_elements/forms/__init__.py
Expand Up @@ -331,10 +331,18 @@ def __init__(self, label=None, orm_cls=None, label_attr=None, fk_attr='id',
query_filter=None, coerce=_not_given, **kwargs):
label = self.field_label_modifier(label)
self.orm_cls = orm_cls

self.label_attr = label_attr
if self.label_attr is None:
self.label_attr = self.get_best_label_attr()
if self.label_attr:
# compute this once and store on the object, rather than for each option
self.label_attr_name = self.compute_option_attr_name(self.label_attr)

self.fk_attr = fk_attr
if self.fk_attr:
self.fk_attr_name = self.compute_option_attr_name(self.fk_attr)

self.query_filter = query_filter
if not self.fk_attr and not coerce:
def coerce_to_orm_obj(value):
Expand All @@ -350,6 +358,19 @@ def coerce_to_orm_obj(value):

super().__init__(label=label, choices=self.get_choices, coerce=coerce, **kwargs)

def compute_option_attr_name(self, base_attr):
"""To access the label value on an option in the result set, we need to know what
attribute to use. Based on ``label_attr``, which can be string, SA ORM attr,
SA hybrid attr, etc., determine a sane attribute."""
retval = base_attr

if isinstance(base_attr, sa.orm.InstrumentedAttribute):
retval = base_attr.name
else:
retval = getattr(base_attr, '__name__', str(base_attr))

return retval

def field_label_modifier(self, label):
"""Modifies the label to something more human-friendly.
Expand All @@ -366,11 +387,14 @@ def build_query(self):
return query

def query_base(self):
return self.orm_cls.query.order_by(self.label_attr)
orm_label_attr = self.label_attr
if isinstance(self.label_attr, str):
orm_label_attr = getattr(self.orm_cls, self.label_attr)
return self.orm_cls.query.order_by(orm_label_attr)

def get_data_filter(self):
if self.fk_attr:
return getattr(self.orm_cls, self.fk_attr) == self.data
return getattr(self.orm_cls, self.fk_attr_name) == self.data
else:
return self.orm_cls.id == self.data.id

Expand Down Expand Up @@ -407,7 +431,7 @@ def get_best_label_attr(self):

def get_option_label(self, obj):
if self.label_attr:
return getattr(obj, self.label_attr)
return getattr(obj, self.label_attr_name)

return str(obj)

Expand All @@ -416,7 +440,7 @@ def get_choices(self):

def get_value(obj):
if self.fk_attr:
return str(getattr(obj, self.fk_attr))
return str(getattr(obj, self.fk_attr_name))

return str(obj.id)

Expand All @@ -437,9 +461,6 @@ class RelationshipField(RelationshipFieldBase, SelectField):
orm_cls (class): Model class of the relationship attribute. Used to query
records for populating select options.
relationship_attr (str): Name of the attribute on form model that refers to
the relationship object. Typically this is a foreign key ID.
label_attr (str): Name of attribute on relationship class to use for select
option labels.
Expand Down Expand Up @@ -478,9 +499,6 @@ class RelationshipMultipleField(RelationshipFieldBase, SelectMultipleField):
orm_cls (class): Model class of the relationship attribute. Used to query
records for populating select options.
relationship_attr (str): Name of the collection on form model that refers to
the relationship object.
label_attr (str): Name of attribute on relationship class to use for select
option labels.
Expand Down
51 changes: 50 additions & 1 deletion keg_elements/tests/test_forms/test_form.py
Expand Up @@ -944,7 +944,6 @@ def test_options_include_form_obj_value(self):
)

self.assert_object_options(self.get_field(form), [foo_thing, thing1])
assert self.get_field(form).data == self.coerce(foo_thing.id)

def test_options_sorted(self):
thing_b = ents.Thing.testing_create(name='BBB')
Expand Down Expand Up @@ -1017,6 +1016,56 @@ def test_coerce_formdata(self):
assert form.thing.data == thing


class TestOrmRelationshipOrmAttr(TestOrmRelationship):
"""Test a relationship field that uses an ORM attr for the label."""

def create_relationship(self, query_filter=None):
return RelationshipField('Thing', ents.Thing, ents.Thing.name, query_filter=query_filter,
fk_attr=None)


class TestOrmRelationshipOrmFkAttr(TestOrmRelationship):
"""Test a relationship field that uses an ORM attr for the fk."""

def create_relationship(self, query_filter=None):
return RelationshipField('Thing', ents.Thing, 'name', query_filter=query_filter,
fk_attr=ents.Thing.id)

def create_form(self, relationship, **kwargs):
class RelatedThingForm(ModelForm):
thing_id = relationship

class Meta:
model = ents.RelatedThing

return RelatedThingForm(**kwargs)

def get_field(self, form):
return form.thing_id

def test_coerce_formdata(self):
thing = ents.Thing.testing_create()
form = self.create_form(self.create_relationship(),
formdata=MultiDict({'thing_id': str(thing.id)}))
assert form.thing_id.data == thing.id


class TestOrmRelationshipOrmHybridAttr(TestOrmRelationshipOrmAttr):
"""Test a relationship field that uses an ORM hybrid property attr for the label."""

def create_relationship(self, query_filter=None):
return RelationshipField('Thing', ents.Thing, ents.Thing.hybrid_name,
query_filter=query_filter, fk_attr=None)


class TestOrmRelationshipOrmHybridFkAttr(TestOrmRelationshipOrmFkAttr):
"""Test a relationship field that uses an ORM hybrid property for the fk."""

def create_relationship(self, query_filter=None):
return RelationshipField('Thing', ents.Thing, 'name', query_filter=query_filter,
fk_attr=ents.Thing.hybrid_id)


class TestRelationshipFieldGenerator:
def setup_method(self):
ents.Thing.delete_cascaded()
Expand Down
16 changes: 16 additions & 0 deletions kegel_app/model/entities.py
Expand Up @@ -56,6 +56,22 @@ def name_and_color(self):
def name_and_color(cls):
return cls.name + sa.sql.literal('-') + cls.color

@sa.ext.hybrid.hybrid_property
def hybrid_name(self):
return self.name

@hybrid_name.expression
def hybrid_name(cls):
return cls.name

@sa.ext.hybrid.hybrid_property
def hybrid_id(self):
return self.id

@hybrid_id.expression
def hybrid_id(cls):
return cls.id

@classmethod
def random_color(cls):
return 'blue'
Expand Down

0 comments on commit 97244de

Please sign in to comment.