diff --git a/keg_elements/forms/__init__.py b/keg_elements/forms/__init__.py index 7b04636..1f2177a 100644 --- a/keg_elements/forms/__init__.py +++ b/keg_elements/forms/__init__.py @@ -331,10 +331,18 @@ def __init__(self, label=None, orm_cls=None, label_attr=None, fk_attr='id', query_filter=None, coerce=_not_given, **kwargs): label = self.field_label_modifier(label) self.orm_cls = orm_cls + self.label_attr = label_attr if self.label_attr is None: self.label_attr = self.get_best_label_attr() + if self.label_attr: + # compute this once and store on the object, rather than for each option + self.label_attr_name = self.compute_option_attr_name(self.label_attr) + self.fk_attr = fk_attr + if self.fk_attr: + self.fk_attr_name = self.compute_option_attr_name(self.fk_attr) + self.query_filter = query_filter if not self.fk_attr and not coerce: def coerce_to_orm_obj(value): @@ -350,6 +358,19 @@ def coerce_to_orm_obj(value): super().__init__(label=label, choices=self.get_choices, coerce=coerce, **kwargs) + def compute_option_attr_name(self, base_attr): + """To access the label value on an option in the result set, we need to know what + attribute to use. Based on ``label_attr``, which can be string, SA ORM attr, + SA hybrid attr, etc., determine a sane attribute.""" + retval = base_attr + + if isinstance(base_attr, sa.orm.InstrumentedAttribute): + retval = base_attr.name + else: + retval = getattr(base_attr, '__name__', str(base_attr)) + + return retval + def field_label_modifier(self, label): """Modifies the label to something more human-friendly. @@ -366,11 +387,14 @@ def build_query(self): return query def query_base(self): - return self.orm_cls.query.order_by(self.label_attr) + orm_label_attr = self.label_attr + if isinstance(self.label_attr, str): + orm_label_attr = getattr(self.orm_cls, self.label_attr) + return self.orm_cls.query.order_by(orm_label_attr) def get_data_filter(self): if self.fk_attr: - return getattr(self.orm_cls, self.fk_attr) == self.data + return getattr(self.orm_cls, self.fk_attr_name) == self.data else: return self.orm_cls.id == self.data.id @@ -407,7 +431,7 @@ def get_best_label_attr(self): def get_option_label(self, obj): if self.label_attr: - return getattr(obj, self.label_attr) + return getattr(obj, self.label_attr_name) return str(obj) @@ -416,7 +440,7 @@ def get_choices(self): def get_value(obj): if self.fk_attr: - return str(getattr(obj, self.fk_attr)) + return str(getattr(obj, self.fk_attr_name)) return str(obj.id) @@ -437,9 +461,6 @@ class RelationshipField(RelationshipFieldBase, SelectField): orm_cls (class): Model class of the relationship attribute. Used to query records for populating select options. - relationship_attr (str): Name of the attribute on form model that refers to - the relationship object. Typically this is a foreign key ID. - label_attr (str): Name of attribute on relationship class to use for select option labels. @@ -478,9 +499,6 @@ class RelationshipMultipleField(RelationshipFieldBase, SelectMultipleField): orm_cls (class): Model class of the relationship attribute. Used to query records for populating select options. - relationship_attr (str): Name of the collection on form model that refers to - the relationship object. - label_attr (str): Name of attribute on relationship class to use for select option labels. diff --git a/keg_elements/tests/test_forms/test_form.py b/keg_elements/tests/test_forms/test_form.py index 6d7066e..d7886de 100644 --- a/keg_elements/tests/test_forms/test_form.py +++ b/keg_elements/tests/test_forms/test_form.py @@ -944,7 +944,6 @@ def test_options_include_form_obj_value(self): ) self.assert_object_options(self.get_field(form), [foo_thing, thing1]) - assert self.get_field(form).data == self.coerce(foo_thing.id) def test_options_sorted(self): thing_b = ents.Thing.testing_create(name='BBB') @@ -1017,6 +1016,56 @@ def test_coerce_formdata(self): assert form.thing.data == thing +class TestOrmRelationshipOrmAttr(TestOrmRelationship): + """Test a relationship field that uses an ORM attr for the label.""" + + def create_relationship(self, query_filter=None): + return RelationshipField('Thing', ents.Thing, ents.Thing.name, query_filter=query_filter, + fk_attr=None) + + +class TestOrmRelationshipOrmFkAttr(TestOrmRelationship): + """Test a relationship field that uses an ORM attr for the fk.""" + + def create_relationship(self, query_filter=None): + return RelationshipField('Thing', ents.Thing, 'name', query_filter=query_filter, + fk_attr=ents.Thing.id) + + def create_form(self, relationship, **kwargs): + class RelatedThingForm(ModelForm): + thing_id = relationship + + class Meta: + model = ents.RelatedThing + + return RelatedThingForm(**kwargs) + + def get_field(self, form): + return form.thing_id + + def test_coerce_formdata(self): + thing = ents.Thing.testing_create() + form = self.create_form(self.create_relationship(), + formdata=MultiDict({'thing_id': str(thing.id)})) + assert form.thing_id.data == thing.id + + +class TestOrmRelationshipOrmHybridAttr(TestOrmRelationshipOrmAttr): + """Test a relationship field that uses an ORM hybrid property attr for the label.""" + + def create_relationship(self, query_filter=None): + return RelationshipField('Thing', ents.Thing, ents.Thing.hybrid_name, + query_filter=query_filter, fk_attr=None) + + +class TestOrmRelationshipOrmHybridFkAttr(TestOrmRelationshipOrmFkAttr): + """Test a relationship field that uses an ORM hybrid property for the fk.""" + + def create_relationship(self, query_filter=None): + return RelationshipField('Thing', ents.Thing, 'name', query_filter=query_filter, + fk_attr=ents.Thing.hybrid_id) + + class TestRelationshipFieldGenerator: def setup_method(self): ents.Thing.delete_cascaded() diff --git a/kegel_app/model/entities.py b/kegel_app/model/entities.py index 49b6c84..80c882a 100644 --- a/kegel_app/model/entities.py +++ b/kegel_app/model/entities.py @@ -56,6 +56,22 @@ def name_and_color(self): def name_and_color(cls): return cls.name + sa.sql.literal('-') + cls.color + @sa.ext.hybrid.hybrid_property + def hybrid_name(self): + return self.name + + @hybrid_name.expression + def hybrid_name(cls): + return cls.name + + @sa.ext.hybrid.hybrid_property + def hybrid_id(self): + return self.id + + @hybrid_id.expression + def hybrid_id(cls): + return cls.id + @classmethod def random_color(cls): return 'blue'