Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reworked merge_declarative_args() in terms of __{global,local}_{mapper,table}_args__. #18

Merged
merged 1 commit into from
Mar 24, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions alchy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
"""

from sqlalchemy import inspect, orm
from sqlalchemy.ext.declarative import declarative_base, DeclarativeMeta
from sqlalchemy.ext.declarative import (
declarative_base,
DeclarativeMeta,
declared_attr,
)

from . import query, events
from .utils import (
Expand All @@ -11,7 +15,7 @@
camelcase_to_underscore,
get_mapper_class,
merge_mapper_args,
merge_table_args
merge_table_args,
)
from ._compat import iteritems

Expand Down Expand Up @@ -70,16 +74,6 @@ def __init__(cls, name, bases, dct):

base_dcts = [dct] + [base.__dict__ for base in bases]

# Merge __mapper_args__ from all base classes.
__mapper_args__ = merge_mapper_args(cls, base_dcts)
if __mapper_args__:
cls.__mapper_args__ = __mapper_args__

# Merge __table_args__ from all base classes.
__table_args__ = merge_table_args(cls, base_dcts)
if __table_args__:
cls.__table_args__ = __table_args__


class ModelBase(object):
"""Base class for creating a declarative base for models.
Expand Down Expand Up @@ -131,6 +125,14 @@ class User(Model):
query_class = query.QueryModel
query = None

@declared_attr
def __mapper_args__(cls): # pylint: disable=no-self-argument
return merge_mapper_args(cls)

@declared_attr
def __table_args__(cls): # pylint: disable=no-self-argument
return merge_table_args(cls)

def __init__(self, *args, **kargs):
"""Initialize model instance by calling :meth:`update`."""
self.update(*args, **kargs)
Expand Down
35 changes: 16 additions & 19 deletions alchy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def get_mapper_class(model, field):
return mapper_class(getattr(model, field))


def merge_declarative_args(cls, base_dcts, config_key):
"""Given a list of base dicts, merge declarative args identified by
`config_key` into a single configuration object.
def merge_declarative_args(cls, global_config_key, local_config_key):
"""Merge declarative args for class `cls`
identified by `global_config_key` and `local_config_key`
into a consolidated (tuple, dict).
"""
configs = [base.get(config_key) for base in reversed(base_dcts)]
configs = [base.__dict__.get(global_config_key)
for base in reversed(cls.mro())]
configs.append(cls.__dict__.get(local_config_key))

args = []
kargs = {}

Expand Down Expand Up @@ -125,31 +129,24 @@ def merge_declarative_args(cls, base_dcts, config_key):
return (args, kargs)


def merge_mapper_args(cls, base_dcts):
def merge_mapper_args(cls):
"""Merge `__mapper_args__` from all base dictionaries and
`__local_mapper_args__` from first base into single inherited object.
"""
_, kargs = merge_declarative_args(cls, base_dcts, '__mapper_args__')
_, local_kargs = merge_declarative_args(cls,
base_dcts[:1],
'__local_mapper_args__')

kargs.update(local_kargs)
_, kargs = merge_declarative_args(cls,
'__global_mapper_args__',
'__local_mapper_args__')

return kargs


def merge_table_args(cls, base_dcts):
def merge_table_args(cls):
"""Merge `__table_args__` from all base dictionaries and
`__local_table_args__` from first base into single inherited object.
"""
args, kargs = merge_declarative_args(cls, base_dcts, '__table_args__')
local_args, local_kargs = merge_declarative_args(cls,
base_dcts[:1],
'__local_table_args__')

args = unique(args + local_args)
kargs.update(local_kargs)
args, kargs = merge_declarative_args(cls,
'__global_table_args__',
'__local_table_args__')

# Append kargs onto end of args to adhere to SQLAlchemy requirements.
args.append(kargs)
Expand Down
92 changes: 57 additions & 35 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,33 +421,63 @@ def test_multiple_primary_keys(self):
self.assertEqual(MultiplePrimaryKey.primary_key(),
inspect(MultiplePrimaryKey).primary_key)

def test_inherited_mapper_args(self):
class Abstract(object):
id = Column(types.Integer(), primary_key=True)
string = Column(types.String())
number = Column(types.Integer())

__global_mapper_args__ = {'column_prefix': '_',
'order_by': 'number'}

class Mixin(object):
name = Column(types.String())

__global_mapper_args__ = {'column_prefix': '__'}

class Obj2(Model, Mixin, Abstract):
text = Column(types.Text())

__local_mapper_args__ = {'order_by': 'string'}

self.assertEqual(Obj2.__mapper_args__,
{'column_prefix': '__', 'order_by': 'string'})

def test_inherited_table_args(self):
class Abstract(object):
id = Column(types.Integer(), primary_key=True)
string = Column(types.String())
number = Column(types.Integer())

__table_args__ = (Index('idx_abstract_string', 'string'),
Index('idx_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})
__global_table_args__ = (Index('idx_abstract_string', 'string'),
Index('idx_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})

__local_table_args__ = {'not_inherited': 'ignored'}

class Mixin(object):
name = Column(types.String())

__table_args__ = (Index('idx_name', 'name'),)
__global_table_args__ = (Index('idx_name', 'name'),)

class Obj(Model, Mixin, Abstract):
text = Column(types.Text())
__local_table_args__ = (Index('idx_obj_text', 'text'),
{'mysql_foo': 'foo'})

__global_table_args__ = (Index('idx_obj_text', 'text'),
{'mysql_foo': 'foo'})

__local_table_args__ = (Index('idx_obj_text2', 'text'),
{'mysql_baz': 'baz'})

self.assertEqual(Obj.__table_args__[-1],
{'mysql_foo': 'foo', 'mysql_bar': 'bar'})
{'mysql_foo': 'foo', 'mysql_bar': 'bar',
'mysql_baz': 'baz'})

expected_indexes = ['idx_abstract_string',
'idx_abstract_number',
'idx_name',
'idx_obj_text']
'idx_obj_text',
'idx_obj_text2']

for i, name in enumerate(expected_indexes):
self.assertEqual(Obj.__table_args__[i].name, name)
Expand All @@ -459,56 +489,48 @@ class AbstractCM(object):
string = Column(types.String())
number = Column(types.Integer())

__table_args__ = (Index('idx_cm_abstract_string', 'string'),
Index('idx_cm_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})
@classmethod
def __global_table_args__(cls):
return (Index('idx_cm_abstract_string', 'string'),
Index('idx_cm_abstract_number', 'number'),
{'mysql_foo': 'bar', 'mysql_bar': 'bar'})

def __local_table_args__(self):
return {'not_inherited': 'ignored'}

class MixinCM(object):
name = Column(types.String())

def __table_args__():
def __global_table_args__():
return (Index('idx_cm_name', 'name'),)

class ObjCM(Model, MixinCM, AbstractCM):
text = Column(types.Text())

@classmethod
def __local_table_args__(cls):
def __global_table_args__(cls):
return (Index('idx_cm_obj_text', 'text'),
{'mysql_foo': 'foo'})

@classmethod
def __local_table_args__(cls):
return (Index('idx_cm_obj_text2', 'text'),
{'mysql_baz': 'baz'})

self.assertEqual(ObjCM.__table_args__[-1],
{'mysql_foo': 'foo', 'mysql_bar': 'bar'})
{'mysql_foo': 'foo', 'mysql_bar': 'bar',
'mysql_baz': 'baz'})

expected_indexes = ['idx_cm_abstract_string',
'idx_cm_abstract_number',
'idx_cm_name',
'idx_cm_obj_text']
'idx_cm_obj_text',
'idx_cm_obj_text2']

for i, name in enumerate(expected_indexes):
self.assertEqual(ObjCM.__table_args__[i].name, name)
self.assertIsInstance(ObjCM.__table_args__[i], Index)

def test_inherited_mapper_args(self):
class Abstract(object):
id = Column(types.Integer(), primary_key=True)
string = Column(types.String())
number = Column(types.Integer())

__mapper_args__ = {'column_prefix': '_', 'order_by': 'number'}

class Mixin(object):
name = Column(types.String())

__mapper_args__ = {'column_prefix': '__'}

class Obj2(Model, Mixin, Abstract):
text = Column(types.Text())
__local_mapper_args__ = {'order_by': 'string'}

self.assertEqual(Obj2.__mapper_args__,
{'column_prefix': '__', 'order_by': 'string'})

def test_is_modified(self):
record = Foo.get(1)
self.assertEqual(record.is_modified(), False)
Expand Down