Skip to content

Commit

Permalink
should_set_tablename() doesn't handle declared_attr.
Browse files Browse the repository at this point in the history
  • Loading branch information
seth-p committed Apr 25, 2015
1 parent fba26a6 commit 20ba678
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 51 deletions.
16 changes: 10 additions & 6 deletions alchy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ class ModelMeta(DeclarativeMeta):
or :attr:`ModelBase.__events__`.
"""
def __new__(mcs, name, bases, dct):
# Determine if should set __tablename__.
if should_set_tablename(bases, dct):
# Set to underscore version of class name.
dct['__tablename__'] = camelcase_to_underscore(name)

# Set __events__ to expected default so that it's updatable when
# initializing. E.g. if class definition sets __events__=None but
# defines decorated events, then we want the final __events__ attribute
Expand All @@ -57,7 +52,16 @@ def __new__(mcs, name, bases, dct):
dct['__bind_key__'] = base['__bind_key__']
break

return DeclarativeMeta.__new__(mcs, name, bases, dct)
cls = DeclarativeMeta.__new__(mcs, name, bases, dct)

# Determine if should set __tablename__.
# This is being done after DeclarativeMeta.__new__()
# as the class is needed to accommodate @declared_attr columns.
if should_set_tablename(cls):
# Set to underscore version of class name.
cls.__tablename__ = camelcase_to_underscore(name)

return cls

def __init__(cls, name, bases, dct):
DeclarativeMeta.__init__(cls, name, bases, dct)
Expand Down
125 changes: 82 additions & 43 deletions alchy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
"""

import re
import warnings
from collections import Iterable

from sqlalchemy import Column
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.exc import SAWarning
from sqlalchemy.ext.declarative import AbstractConcreteBase, declared_attr

from ._compat import string_types, iteritems, itervalues, classmethod_func
from ._compat import string_types, itervalues, iteritems, classmethod_func


__all__ = [
'is_sequence',
'has_primary_key',
'camelcase_to_underscore',
'iterflatten',
'flatten'
Expand All @@ -31,13 +32,6 @@ def is_sequence(obj):
not isinstance(obj, dict))


def has_primary_key(metadict):
"""Check if meta class' dict object has a primary key defined."""
return any(column.primary_key
for attr, column in iteritems(metadict)
if isinstance(column, Column))


def camelcase_to_underscore(string):
"""Convert string from ``CamelCase`` to ``under_score``."""
return re.sub('((?<=[a-z0-9])[A-Z]|(?<!_)(?!^)[A-Z](?=[a-z]))', r'_\1',
Expand Down Expand Up @@ -93,6 +87,26 @@ def get_mapper_class(model, field):
return mapper_class(getattr(model, field))


def get_concrete_value(obj,
cls,
check_classmethod=False,
check_callable=False):
"""Returns a 'concrete' form of obj.
If obj is a declared_attr, it is evaluated on cls.
If obj is a classmethod and check_classmethod is True,
it is evaluated on cls.
If obj is callable and check_callable is True, it is evaluated.
"""
if isinstance(obj, declared_attr):
return obj.fget(cls)
elif check_classmethod and isinstance(obj, classmethod):
return classmethod_func(obj)(cls)
elif check_callable and callable(obj):
return obj()
else:
return obj


def merge_declarative_args(cls, global_config_key, local_config_key):
"""Merge declarative args for class `cls`
identified by `global_config_key` and `local_config_key`
Expand All @@ -109,10 +123,10 @@ def merge_declarative_args(cls, global_config_key, local_config_key):
if not obj:
continue

if isinstance(obj, classmethod):
obj = classmethod_func(obj)(cls)
elif callable(obj):
obj = obj()
obj = get_concrete_value(obj,
cls,
check_classmethod=True,
check_callable=True)

if isinstance(obj, dict):
kargs.update(obj)
Expand All @@ -130,7 +144,7 @@ def merge_declarative_args(cls, global_config_key, local_config_key):
return (args, kargs)


def should_set_tablename(bases, dct):
def should_set_tablename(cls):
"""Check what values are set by a class and its bases to determine if a
tablename should be automatically generated.
Expand All @@ -156,38 +170,63 @@ def should_set_tablename(bases, dct):
:return: True if tablename should be set
"""

dct = cls.__dict__

if '__tablename__' in dct or '__table__' in dct or '__abstract__' in dct:
return False

if AbstractConcreteBase in bases:
if AbstractConcreteBase in cls.__bases__:
return False

if has_primary_key(dct):
return True

if '__mapper_args__' in dct:
is_concrete = dct.get('__mapper_args__', {}).get('concrete', False)
else:
is_concrete = dct.get('__global_mapper_args__', {}).get('concrete',
False)
is_concrete = dct.get('__local_mapper_args__', {}).get('concrete',
is_concrete)

names_to_ignore = set(dct.keys())
names_to_ignore.add('query')

for base in bases:
if (not is_concrete) and (hasattr(base, '__tablename__') or
hasattr(base, '__table__')):
return False

for name in dir(base):
if not (name in names_to_ignore or
(name.startswith('__') and name.endswith('__'))):
attr = getattr(base, name)
if getattr(attr, 'primary_key', False):
return True
else:
names_to_ignore.add(name)
def is_primary_key_column(obj):
obj = get_concrete_value(obj, cls)
return isinstance(obj, Column) and obj.primary_key

with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message='Unmanaged access of '
'declarative attribute .*',
category=SAWarning)

for name, value in iteritems(dct):
if is_primary_key_column(value):
return True

if '__mapper_args__' in dct:
mapper_args = get_concrete_value(dct.get('__mapper_args__', {}),
cls,
check_classmethod=True,
check_callable=True)
is_concrete = mapper_args.get('concrete', False)
else:
mapper_args = get_concrete_value(dct.get('__global_mapper_args__',
{}),
cls,
check_classmethod=True,
check_callable=True)
is_concrete = mapper_args.get('concrete', False)
mapper_args = get_concrete_value(dct.get('__local_mapper_args__',
{}),
cls,
check_classmethod=True,
check_callable=True)
is_concrete = mapper_args.get('concrete', is_concrete)

names_to_ignore = set(dct.keys())
names_to_ignore.add('query')

for base in cls.__bases__:
if (not is_concrete) and (hasattr(base, '__tablename__') or
hasattr(base, '__table__')):
return False

for k in base.mro():
for name, value in iteritems(k.__dict__):
if not (name in names_to_ignore or
(name.startswith('__') and name.endswith('__'))):
if is_primary_key_column(value):
return True
else:
names_to_ignore.add(name)

return False
59 changes: 57 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

from sqlalchemy import orm, Column, types, inspect, Index
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy import orm, Column, types, inspect, Index, ForeignKey
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import relationship
from sqlalchemy.orm.exc import UnmappedClassError

from alchy import query

Expand Down Expand Up @@ -560,3 +562,56 @@ def test_should_set_tablename(self):
self.assertEqual(getattr(KKK, '__tablename__'), 'kkk')
self.assertEqual(hasattr(LLL, '__tablename__'), False)
self.assertEqual(getattr(MMM, '__tablename__'), 'mmm')

def test_should_set_tablename_declared_attr(self):
class IntColumnMixin(object):
@declared_attr
def mixed_in_column(cls):
return Column(types.Integer(),
primary_key=cls.__name__.endswith('Primary'))

# MixedInTablePrimary.mixed_in_column.primary_key is True,
# and so __tablename__ should be generated
class MixedInTablePrimary(IntColumnMixin, Model):
pass

self.assertEqual(getattr(MixedInTablePrimary, '__tablename__'),
'mixed_in_table_primary')

def get_MixedInTableNoPrimaryKey():
# MixedInTableNoPrimaryKey.mixed_in_column.primary_key is False,
# so no __tablename__ should be generated
class MixedInTableNoPrimaryKey(IntColumnMixin, Model):
pass

return MixedInTableNoPrimaryKey

self.assertRaises(InvalidRequestError, get_MixedInTableNoPrimaryKey)

# test that can handle mixed-in @declared_attr relationships
class RelationshipMixin(object):
@declared_attr
def foreign_id(cls):
return Column(types.Integer(),
ForeignKey(MixedInTablePrimary.mixed_in_column),
primary_key=cls.__name__.endswith('Primary'))

@declared_attr
def foreign(cls):
return relationship(MixedInTablePrimary,
foreign_keys=[cls.foreign_id])

class MixedInRelationshipPrimary(RelationshipMixin, Model):
pass

self.assertEqual(getattr(MixedInRelationshipPrimary, '__tablename__'),
'mixed_in_relationship_primary')

class MixedInRelAndDeclaredAttrPrimKey(RelationshipMixin, Model):
@declared_attr
def local_primary_key(cls):
return Column(types.Integer(), primary_key=True)

self.assertEqual(getattr(MixedInRelAndDeclaredAttrPrimKey,
'__tablename__'),
'mixed_in_rel_and_declared_attr_prim_key')

0 comments on commit 20ba678

Please sign in to comment.