From 26bce8ebea2c41476dc9ed8bfc9f1655870c61cb Mon Sep 17 00:00:00 2001 From: Seth Padowitz Date: Sat, 21 Mar 2015 11:25:03 -0400 Subject: [PATCH] Reworked merge_declarative_args() in terms of __{global,local}_{mapper,table}_args__. --- alchy/model.py | 26 +++++++------ alchy/utils.py | 35 ++++++++--------- tests/test_model.py | 92 ++++++++++++++++++++++++++++----------------- 3 files changed, 87 insertions(+), 66 deletions(-) diff --git a/alchy/model.py b/alchy/model.py index a92d4ae..c7b69d7 100644 --- a/alchy/model.py +++ b/alchy/model.py @@ -2,7 +2,11 @@ """ from sqlalchemy import inspect, orm -from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta +from sqlalchemy.ext.declarative import ( + declarative_base, + DeclarativeMeta, + declared_attr, +) from . import query, events from .utils import ( @@ -11,7 +15,7 @@ camelcase_to_underscore, get_mapper_class, merge_mapper_args, - merge_table_args + merge_table_args, ) from ._compat import iteritems @@ -70,16 +74,6 @@ def __init__(cls, name, bases, dct): base_dcts = [dct] + [base.__dict__ for base in bases] - # Merge __mapper_args__ from all base classes. - __mapper_args__ = merge_mapper_args(cls, base_dcts) - if __mapper_args__: - cls.__mapper_args__ = __mapper_args__ - - # Merge __table_args__ from all base classes. - __table_args__ = merge_table_args(cls, base_dcts) - if __table_args__: - cls.__table_args__ = __table_args__ - class ModelBase(object): """Base class for creating a declarative base for models. @@ -131,6 +125,14 @@ class User(Model): query_class = query.QueryModel query = None + @declared_attr + def __mapper_args__(cls): # pylint: disable=no-self-argument + return merge_mapper_args(cls) + + @declared_attr + def __table_args__(cls): # pylint: disable=no-self-argument + return merge_table_args(cls) + def __init__(self, *args, **kargs): """Initialize model instance by calling :meth:`update`.""" self.update(*args, **kargs) diff --git a/alchy/utils.py b/alchy/utils.py index d8f38d6..ca438ef 100644 --- a/alchy/utils.py +++ b/alchy/utils.py @@ -92,11 +92,15 @@ def get_mapper_class(model, field): return mapper_class(getattr(model, field)) -def merge_declarative_args(cls, base_dcts, config_key): - """Given a list of base dicts, merge declarative args identified by - `config_key` into a single configuration object. +def merge_declarative_args(cls, global_config_key, local_config_key): + """Merge declarative args for class `cls` + identified by `global_config_key` and `local_config_key` + into a consolidated (tuple, dict). """ - configs = [base.get(config_key) for base in reversed(base_dcts)] + configs = [base.__dict__.get(global_config_key) + for base in reversed(cls.mro())] + configs.append(cls.__dict__.get(local_config_key)) + args = [] kargs = {} @@ -125,31 +129,24 @@ def merge_declarative_args(cls, base_dcts, config_key): return (args, kargs) -def merge_mapper_args(cls, base_dcts): +def merge_mapper_args(cls): """Merge `__mapper_args__` from all base dictionaries and `__local_mapper_args__` from first base into single inherited object. """ - _, kargs = merge_declarative_args(cls, base_dcts, '__mapper_args__') - _, local_kargs = merge_declarative_args(cls, - base_dcts[:1], - '__local_mapper_args__') - - kargs.update(local_kargs) + _, kargs = merge_declarative_args(cls, + '__global_mapper_args__', + '__local_mapper_args__') return kargs -def merge_table_args(cls, base_dcts): +def merge_table_args(cls): """Merge `__table_args__` from all base dictionaries and `__local_table_args__` from first base into single inherited object. """ - args, kargs = merge_declarative_args(cls, base_dcts, '__table_args__') - local_args, local_kargs = merge_declarative_args(cls, - base_dcts[:1], - '__local_table_args__') - - args = unique(args + local_args) - kargs.update(local_kargs) + args, kargs = merge_declarative_args(cls, + '__global_table_args__', + '__local_table_args__') # Append kargs onto end of args to adhere to SQLAlchemy requirements. args.append(kargs) diff --git a/tests/test_model.py b/tests/test_model.py index 9cc67eb..337818f 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -421,33 +421,63 @@ def test_multiple_primary_keys(self): self.assertEqual(MultiplePrimaryKey.primary_key(), inspect(MultiplePrimaryKey).primary_key) + def test_inherited_mapper_args(self): + class Abstract(object): + id = Column(types.Integer(), primary_key=True) + string = Column(types.String()) + number = Column(types.Integer()) + + __global_mapper_args__ = {'column_prefix': '_', + 'order_by': 'number'} + + class Mixin(object): + name = Column(types.String()) + + __global_mapper_args__ = {'column_prefix': '__'} + + class Obj2(Model, Mixin, Abstract): + text = Column(types.Text()) + + __local_mapper_args__ = {'order_by': 'string'} + + self.assertEqual(Obj2.__mapper_args__, + {'column_prefix': '__', 'order_by': 'string'}) + def test_inherited_table_args(self): class Abstract(object): id = Column(types.Integer(), primary_key=True) string = Column(types.String()) number = Column(types.Integer()) - __table_args__ = (Index('idx_abstract_string', 'string'), - Index('idx_abstract_number', 'number'), - {'mysql_foo': 'bar', 'mysql_bar': 'bar'}) + __global_table_args__ = (Index('idx_abstract_string', 'string'), + Index('idx_abstract_number', 'number'), + {'mysql_foo': 'bar', 'mysql_bar': 'bar'}) + + __local_table_args__ = {'not_inherited': 'ignored'} class Mixin(object): name = Column(types.String()) - __table_args__ = (Index('idx_name', 'name'),) + __global_table_args__ = (Index('idx_name', 'name'),) class Obj(Model, Mixin, Abstract): text = Column(types.Text()) - __local_table_args__ = (Index('idx_obj_text', 'text'), - {'mysql_foo': 'foo'}) + + __global_table_args__ = (Index('idx_obj_text', 'text'), + {'mysql_foo': 'foo'}) + + __local_table_args__ = (Index('idx_obj_text2', 'text'), + {'mysql_baz': 'baz'}) self.assertEqual(Obj.__table_args__[-1], - {'mysql_foo': 'foo', 'mysql_bar': 'bar'}) + {'mysql_foo': 'foo', 'mysql_bar': 'bar', + 'mysql_baz': 'baz'}) expected_indexes = ['idx_abstract_string', 'idx_abstract_number', 'idx_name', - 'idx_obj_text'] + 'idx_obj_text', + 'idx_obj_text2'] for i, name in enumerate(expected_indexes): self.assertEqual(Obj.__table_args__[i].name, name) @@ -459,56 +489,48 @@ class AbstractCM(object): string = Column(types.String()) number = Column(types.Integer()) - __table_args__ = (Index('idx_cm_abstract_string', 'string'), - Index('idx_cm_abstract_number', 'number'), - {'mysql_foo': 'bar', 'mysql_bar': 'bar'}) + @classmethod + def __global_table_args__(cls): + return (Index('idx_cm_abstract_string', 'string'), + Index('idx_cm_abstract_number', 'number'), + {'mysql_foo': 'bar', 'mysql_bar': 'bar'}) + + def __local_table_args__(self): + return {'not_inherited': 'ignored'} class MixinCM(object): name = Column(types.String()) - def __table_args__(): + def __global_table_args__(): return (Index('idx_cm_name', 'name'),) class ObjCM(Model, MixinCM, AbstractCM): text = Column(types.Text()) @classmethod - def __local_table_args__(cls): + def __global_table_args__(cls): return (Index('idx_cm_obj_text', 'text'), {'mysql_foo': 'foo'}) + @classmethod + def __local_table_args__(cls): + return (Index('idx_cm_obj_text2', 'text'), + {'mysql_baz': 'baz'}) + self.assertEqual(ObjCM.__table_args__[-1], - {'mysql_foo': 'foo', 'mysql_bar': 'bar'}) + {'mysql_foo': 'foo', 'mysql_bar': 'bar', + 'mysql_baz': 'baz'}) expected_indexes = ['idx_cm_abstract_string', 'idx_cm_abstract_number', 'idx_cm_name', - 'idx_cm_obj_text'] + 'idx_cm_obj_text', + 'idx_cm_obj_text2'] for i, name in enumerate(expected_indexes): self.assertEqual(ObjCM.__table_args__[i].name, name) self.assertIsInstance(ObjCM.__table_args__[i], Index) - def test_inherited_mapper_args(self): - class Abstract(object): - id = Column(types.Integer(), primary_key=True) - string = Column(types.String()) - number = Column(types.Integer()) - - __mapper_args__ = {'column_prefix': '_', 'order_by': 'number'} - - class Mixin(object): - name = Column(types.String()) - - __mapper_args__ = {'column_prefix': '__'} - - class Obj2(Model, Mixin, Abstract): - text = Column(types.Text()) - __local_mapper_args__ = {'order_by': 'string'} - - self.assertEqual(Obj2.__mapper_args__, - {'column_prefix': '__', 'order_by': 'string'}) - def test_is_modified(self): record = Foo.get(1) self.assertEqual(record.is_modified(), False)