Skip to content

Commit

Permalink
Merge pull request #18 from seth-p/local_global_args
Browse files Browse the repository at this point in the history
Reworked merge_declarative_args() in terms of __{global,local}_{mapper,table}_args__.
  • Loading branch information
dgilland committed Mar 24, 2015
2 parents 28ed3f1 + 26bce8e commit b3175a0
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 66 deletions.
26 changes: 14 additions & 12 deletions alchy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
"""

from sqlalchemy import inspect, orm
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.ext.declarative import (
declarative_base,
DeclarativeMeta,
declared_attr,
)

from . import query, events
from .utils import (
Expand All @@ -11,7 +15,7 @@
camelcase_to_underscore,
get_mapper_class,
merge_mapper_args,
merge_table_args
merge_table_args,
)
from ._compat import iteritems

Expand Down Expand Up @@ -70,16 +74,6 @@ def __init__(cls, name, bases, dct):

base_dcts = [dct] + [base.__dict__ for base in bases]

# Merge __mapper_args__ from all base classes.
__mapper_args__ = merge_mapper_args(cls, base_dcts)
if __mapper_args__:
cls.__mapper_args__ = __mapper_args__

# Merge __table_args__ from all base classes.
__table_args__ = merge_table_args(cls, base_dcts)
if __table_args__:
cls.__table_args__ = __table_args__


class ModelBase(object):
"""Base class for creating a declarative base for models.
Expand Down Expand Up @@ -131,6 +125,14 @@ class User(Model):
query_class = query.QueryModel
query = None

@declared_attr
def __mapper_args__(cls): # pylint: disable=no-self-argument
return merge_mapper_args(cls)

@declared_attr
def __table_args__(cls): # pylint: disable=no-self-argument
return merge_table_args(cls)

def __init__(self, *args, **kargs):
"""Initialize model instance by calling :meth:`update`."""
self.update(*args, **kargs)
Expand Down
35 changes: 16 additions & 19 deletions alchy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def get_mapper_class(model, field):
return mapper_class(getattr(model, field))


def merge_declarative_args(cls, base_dcts, config_key):
"""Given a list of base dicts, merge declarative args identified by
`config_key` into a single configuration object.
def merge_declarative_args(cls, global_config_key, local_config_key):
"""Merge declarative args for class `cls`
identified by `global_config_key` and `local_config_key`
into a consolidated (tuple, dict).
"""
configs = [base.get(config_key) for base in reversed(base_dcts)]
configs = [base.__dict__.get(global_config_key)
for base in reversed(cls.mro())]
configs.append(cls.__dict__.get(local_config_key))

args = []
kargs = {}

Expand Down Expand Up @@ -125,31 +129,24 @@ def merge_declarative_args(cls, base_dcts, config_key):
return (args, kargs)


def merge_mapper_args(cls, base_dcts):
def merge_mapper_args(cls):
"""Merge `__mapper_args__` from all base dictionaries and
`__local_mapper_args__` from first base into single inherited object.
"""
_, kargs = merge_declarative_args(cls, base_dcts, '__mapper_args__')
_, local_kargs = merge_declarative_args(cls,
base_dcts[:1],
'__local_mapper_args__')

kargs.update(local_kargs)
_, kargs = merge_declarative_args(cls,
'__global_mapper_args__',
'__local_mapper_args__')

return kargs


def merge_table_args(cls, base_dcts):
def merge_table_args(cls):
"""Merge `__table_args__` from all base dictionaries and
`__local_table_args__` from first base into single inherited object.
"""
args, kargs = merge_declarative_args(cls, base_dcts, '__table_args__')
local_args, local_kargs = merge_declarative_args(cls,
base_dcts[:1],
'__local_table_args__')

args = unique(args + local_args)
kargs.update(local_kargs)
args, kargs = merge_declarative_args(cls,
'__global_table_args__',
'__local_table_args__')

# Append kargs onto end of args to adhere to SQLAlchemy requirements.
args.append(kargs)
Expand Down
92 changes: 57 additions & 35 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,33 +421,63 @@ def test_multiple_primary_keys(self):
self.assertEqual(MultiplePrimaryKey.primary_key(),
inspect(MultiplePrimaryKey).primary_key)

def test_inherited_mapper_args(self):
class Abstract(object):
id = Column(types.Integer(), primary_key=True)
string = Column(types.String())
number = Column(types.Integer())

__global_mapper_args__ = {'column_prefix': '_',
'order_by': 'number'}

class Mixin(object):
name = Column(types.String())

__global_mapper_args__ = {'column_prefix': '__'}

class Obj2(Model, Mixin, Abstract):
text = Column(types.Text())

__local_mapper_args__ = {'order_by': 'string'}

self.assertEqual(Obj2.__mapper_args__,
{'column_prefix': '__', 'order_by': 'string'})

def test_inherited_table_args(self):
class Abstract(object):
id = Column(types.Integer(), primary_key=True)
string = Column(types.String())
number = Column(types.Integer())

__table_args__ = (Index('idx_abstract_string', 'string'),
Index('idx_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})
__global_table_args__ = (Index('idx_abstract_string', 'string'),
Index('idx_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})

__local_table_args__ = {'not_inherited': 'ignored'}

class Mixin(object):
name = Column(types.String())

__table_args__ = (Index('idx_name', 'name'),)
__global_table_args__ = (Index('idx_name', 'name'),)

class Obj(Model, Mixin, Abstract):
text = Column(types.Text())
__local_table_args__ = (Index('idx_obj_text', 'text'),
{'mysql_foo': 'foo'})

__global_table_args__ = (Index('idx_obj_text', 'text'),
{'mysql_foo': 'foo'})

__local_table_args__ = (Index('idx_obj_text2', 'text'),
{'mysql_baz': 'baz'})

self.assertEqual(Obj.__table_args__[-1],
{'mysql_foo': 'foo', 'mysql_bar': 'bar'})
{'mysql_foo': 'foo', 'mysql_bar': 'bar',
'mysql_baz': 'baz'})

expected_indexes = ['idx_abstract_string',
'idx_abstract_number',
'idx_name',
'idx_obj_text']
'idx_obj_text',
'idx_obj_text2']

for i, name in enumerate(expected_indexes):
self.assertEqual(Obj.__table_args__[i].name, name)
Expand All @@ -459,56 +489,48 @@ class AbstractCM(object):
string = Column(types.String())
number = Column(types.Integer())

__table_args__ = (Index('idx_cm_abstract_string', 'string'),
Index('idx_cm_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})
@classmethod
def __global_table_args__(cls):
return (Index('idx_cm_abstract_string', 'string'),
Index('idx_cm_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})

def __local_table_args__(self):
return {'not_inherited': 'ignored'}

class MixinCM(object):
name = Column(types.String())

def __table_args__():
def __global_table_args__():
return (Index('idx_cm_name', 'name'),)

class ObjCM(Model, MixinCM, AbstractCM):
text = Column(types.Text())

@classmethod
def __local_table_args__(cls):
def __global_table_args__(cls):
return (Index('idx_cm_obj_text', 'text'),
{'mysql_foo': 'foo'})

@classmethod
def __local_table_args__(cls):
return (Index('idx_cm_obj_text2', 'text'),
{'mysql_baz': 'baz'})

self.assertEqual(ObjCM.__table_args__[-1],
{'mysql_foo': 'foo', 'mysql_bar': 'bar'})
{'mysql_foo': 'foo', 'mysql_bar': 'bar',
'mysql_baz': 'baz'})

expected_indexes = ['idx_cm_abstract_string',
'idx_cm_abstract_number',
'idx_cm_name',
'idx_cm_obj_text']
'idx_cm_obj_text',
'idx_cm_obj_text2']

for i, name in enumerate(expected_indexes):
self.assertEqual(ObjCM.__table_args__[i].name, name)
self.assertIsInstance(ObjCM.__table_args__[i], Index)

def test_inherited_mapper_args(self):
class Abstract(object):
id = Column(types.Integer(), primary_key=True)
string = Column(types.String())
number = Column(types.Integer())

__mapper_args__ = {'column_prefix': '_', 'order_by': 'number'}

class Mixin(object):
name = Column(types.String())

__mapper_args__ = {'column_prefix': '__'}

class Obj2(Model, Mixin, Abstract):
text = Column(types.Text())
__local_mapper_args__ = {'order_by': 'string'}

self.assertEqual(Obj2.__mapper_args__,
{'column_prefix': '__', 'order_by': 'string'})

def test_is_modified(self):
record = Foo.get(1)
self.assertEqual(record.is_modified(), False)
Expand Down

0 comments on commit b3175a0

Please sign in to comment.