diff --git a/alchy/_compat.py b/alchy/_compat.py index da8f5d9..8b0bb3c 100644 --- a/alchy/_compat.py +++ b/alchy/_compat.py @@ -102,3 +102,10 @@ def __exit__(self, *args): BROKEN_PYPY_CTXMGR_EXIT = True except AssertionError: pass + + +# Define classmethod_func(f) to retrieve the unbound function of classmethod f +if sys.version_info[:2] >= (2, 7): + def classmethod_func(f): return f.__func__ +else: + def classmethod_func(f): return f.__get__(1).im_func diff --git a/alchy/model.py b/alchy/model.py index 79bf863..a92d4ae 100644 --- a/alchy/model.py +++ b/alchy/model.py @@ -2,11 +2,7 @@ """ from sqlalchemy import inspect, orm -from sqlalchemy.ext.declarative import ( - declarative_base, - DeclarativeMeta, - declared_attr -) +from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta from . import query, events from .utils import ( @@ -14,8 +10,8 @@ has_primary_key, camelcase_to_underscore, get_mapper_class, - merge_declarative_args, - unique + merge_mapper_args, + merge_table_args ) from ._compat import iteritems @@ -72,6 +68,18 @@ def __init__(cls, name, bases, dct): events.register(cls, dct) + base_dcts = [dct] + [base.__dict__ for base in bases] + + # Merge __mapper_args__ from all base classes. + __mapper_args__ = merge_mapper_args(cls, base_dcts) + if __mapper_args__: + cls.__mapper_args__ = __mapper_args__ + + # Merge __table_args__ from all base classes. + __table_args__ = merge_table_args(cls, base_dcts) + if __table_args__: + cls.__table_args__ = __table_args__ + class ModelBase(object): """Base class for creating a declarative base for models. @@ -123,33 +131,6 @@ class User(Model): query_class = query.QueryModel query = None - @declared_attr - def __table_args__(cls): # pylint: disable=no-self-argument - # pylint: disable=no-member - args, kargs = merge_declarative_args(cls.__bases__, '__table_args__') - local_args, local_kargs = merge_declarative_args( - [cls], '__local_table_args__') - - args = unique(args + local_args) - kargs.update(local_kargs) - - # Append kargs onto end of args to adhere to SQLAlchemy requirements. - args.append(kargs) - - return tuple(args) - - @declared_attr - def __mapper_args__(cls): # pylint: disable=no-self-argument - # pylint: disable=no-member - # NOTE: Mapper args are only allowed to be a dict so we ignore `args`. - _, kargs = merge_declarative_args(cls.__bases__, '__mapper_args__') - _, local_kargs = merge_declarative_args([cls], - '__local_mapper_args__') - - kargs.update(local_kargs) - - return kargs - def __init__(self, *args, **kargs): """Initialize model instance by calling :meth:`update`.""" self.update(*args, **kargs) diff --git a/alchy/utils.py b/alchy/utils.py index daef94c..d8f38d6 100644 --- a/alchy/utils.py +++ b/alchy/utils.py @@ -9,7 +9,7 @@ from sqlalchemy import Column -from ._compat import string_types, iteritems +from ._compat import string_types, iteritems, classmethod_func __all__ = [ @@ -92,12 +92,11 @@ def get_mapper_class(model, field): return mapper_class(getattr(model, field)) -def merge_declarative_args(base_classes, config_key): - """Given a list of base classes, merge declarative args identified by +def merge_declarative_args(cls, base_dcts, config_key): + """Given a list of base dicts, merge declarative args identified by `config_key` into a single configuration object. """ - configs = [getattr(base, config_key, None) - for base in reversed(base_classes)] + configs = [base.get(config_key) for base in reversed(base_dcts)] args = [] kargs = {} @@ -105,7 +104,9 @@ def merge_declarative_args(base_classes, config_key): if not obj: continue - if callable(obj): + if isinstance(obj, classmethod): + obj = classmethod_func(obj)(cls) + elif callable(obj): obj = obj() if isinstance(obj, dict): @@ -122,3 +123,35 @@ def merge_declarative_args(base_classes, config_key): args = unique(args) return (args, kargs) + + +def merge_mapper_args(cls, base_dcts): + """Merge `__mapper_args__` from all base dictionaries and + `__local_mapper_args__` from first base into single inherited object. + """ + _, kargs = merge_declarative_args(cls, base_dcts, '__mapper_args__') + _, local_kargs = merge_declarative_args(cls, + base_dcts[:1], + '__local_mapper_args__') + + kargs.update(local_kargs) + + return kargs + + +def merge_table_args(cls, base_dcts): + """Merge `__table_args__` from all base dictionaries and + `__local_table_args__` from first base into single inherited object. + """ + args, kargs = merge_declarative_args(cls, base_dcts, '__table_args__') + local_args, local_kargs = merge_declarative_args(cls, + base_dcts[:1], + '__local_table_args__') + + args = unique(args + local_args) + kargs.update(local_kargs) + + # Append kargs onto end of args to adhere to SQLAlchemy requirements. + args.append(kargs) + + return tuple(args) diff --git a/tests/test_model.py b/tests/test_model.py index 5ed823d..9cc67eb 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -466,7 +466,8 @@ class AbstractCM(object): class MixinCM(object): name = Column(types.String()) - __table_args__ = (Index('idx_cm_name', 'name'),) + def __table_args__(): + return (Index('idx_cm_name', 'name'),) class ObjCM(Model, MixinCM, AbstractCM): text = Column(types.Text())