Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

should_set_tablename() doesn't handle declared_attr. #23

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions alchy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ class ModelMeta(DeclarativeMeta):
or :attr:`ModelBase.__events__`.
"""
def __new__(mcs, name, bases, dct):
# Determine if should set __tablename__.
if should_set_tablename(bases, dct):
# Set to underscore version of class name.
dct['__tablename__'] = camelcase_to_underscore(name)

# Set __events__ to expected default so that it's updatable when
# initializing. E.g. if class definition sets __events__=None but
# defines decorated events, then we want the final __events__ attribute
Expand All @@ -57,7 +52,16 @@ def __new__(mcs, name, bases, dct):
dct['__bind_key__'] = base['__bind_key__']
break

return DeclarativeMeta.__new__(mcs, name, bases, dct)
cls = DeclarativeMeta.__new__(mcs, name, bases, dct)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment here about why tablename is being set after call to DeclarativeMeta.__new__.


# Determine if should set __tablename__.
# This is being done after DeclarativeMeta.__new__()
# as the class is needed to accommodate @declared_attr columns.
if should_set_tablename(cls):
# Set to underscore version of class name.
cls.__tablename__ = camelcase_to_underscore(name)

return cls

def __init__(cls, name, bases, dct):
DeclarativeMeta.__init__(cls, name, bases, dct)
Expand Down
125 changes: 82 additions & 43 deletions alchy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@
"""

import re
import warnings
from collections import Iterable

from sqlalchemy import Column
from sqlalchemy.ext.declarative import AbstractConcreteBase
from sqlalchemy.exc import SAWarning
from sqlalchemy.ext.declarative import AbstractConcreteBase, declared_attr

from ._compat import string_types, iteritems, itervalues, classmethod_func
from ._compat import string_types, itervalues, iteritems, classmethod_func


__all__ = [
'is_sequence',
'has_primary_key',
'camelcase_to_underscore',
'iterflatten',
'flatten'
Expand All @@ -31,13 +32,6 @@ def is_sequence(obj):
not isinstance(obj, dict))


def has_primary_key(metadict):
"""Check if meta class' dict object has a primary key defined."""
return any(column.primary_key
for attr, column in iteritems(metadict)
if isinstance(column, Column))


def camelcase_to_underscore(string):
"""Convert string from ``CamelCase`` to ``under_score``."""
return re.sub('((?<=[a-z0-9])[A-Z]|(?<!_)(?!^)[A-Z](?=[a-z]))', r'_\1',
Expand Down Expand Up @@ -93,6 +87,26 @@ def get_mapper_class(model, field):
return mapper_class(getattr(model, field))


def get_concrete_value(obj,
cls,
check_classmethod=False,
check_callable=False):
"""Return a 'concrete' form of obj.
If obj is a declared_attr, it is evaluated on cls.
If obj is a classmethod and check_classmethod is True,
it is evaluated on cls.
If obj is callable and check_callable is True, it is evaluated.
"""
if isinstance(obj, declared_attr):
return obj.fget(cls)
elif check_classmethod and isinstance(obj, classmethod):
return classmethod_func(obj)(cls)
elif check_callable and callable(obj):
return obj()
else:
return obj


def merge_declarative_args(cls, global_config_key, local_config_key):
"""Merge declarative args for class `cls`
identified by `global_config_key` and `local_config_key`
Expand All @@ -109,10 +123,10 @@ def merge_declarative_args(cls, global_config_key, local_config_key):
if not obj:
continue

if isinstance(obj, classmethod):
obj = classmethod_func(obj)(cls)
elif callable(obj):
obj = obj()
obj = get_concrete_value(obj,
cls,
check_classmethod=True,
check_callable=True)

if isinstance(obj, dict):
kargs.update(obj)
Expand All @@ -130,7 +144,7 @@ def merge_declarative_args(cls, global_config_key, local_config_key):
return (args, kargs)


def should_set_tablename(bases, dct):
def should_set_tablename(cls):
"""Check what values are set by a class and its bases to determine if a
tablename should be automatically generated.

Expand All @@ -156,38 +170,63 @@ def should_set_tablename(bases, dct):
:return: True if tablename should be set
"""

dct = cls.__dict__

if '__tablename__' in dct or '__table__' in dct or '__abstract__' in dct:
return False

if AbstractConcreteBase in bases:
if AbstractConcreteBase in cls.__bases__:
return False

if has_primary_key(dct):
return True

if '__mapper_args__' in dct:
is_concrete = dct.get('__mapper_args__', {}).get('concrete', False)
else:
is_concrete = dct.get('__global_mapper_args__', {}).get('concrete',
False)
is_concrete = dct.get('__local_mapper_args__', {}).get('concrete',
is_concrete)

names_to_ignore = set(dct.keys())
names_to_ignore.add('query')

for base in bases:
if (not is_concrete) and (hasattr(base, '__tablename__') or
hasattr(base, '__table__')):
return False

for name in dir(base):
if not (name in names_to_ignore or
(name.startswith('__') and name.endswith('__'))):
attr = getattr(base, name)
if getattr(attr, 'primary_key', False):
return True
else:
names_to_ignore.add(name)
def is_primary_key_column(obj):
obj = get_concrete_value(obj, cls)
return isinstance(obj, Column) and obj.primary_key

with warnings.catch_warnings():
warnings.filterwarnings("ignore",
message='Unmanaged access of '
'declarative attribute .*',
category=SAWarning)

for name, value in iteritems(dct):
if is_primary_key_column(value):
return True

if '__mapper_args__' in dct:
mapper_args = get_concrete_value(dct.get('__mapper_args__', {}),
cls,
check_classmethod=True,
check_callable=True)
is_concrete = mapper_args.get('concrete', False)
else:
mapper_args = get_concrete_value(dct.get('__global_mapper_args__',
{}),
cls,
check_classmethod=True,
check_callable=True)
is_concrete = mapper_args.get('concrete', False)
mapper_args = get_concrete_value(dct.get('__local_mapper_args__',
{}),
cls,
check_classmethod=True,
check_callable=True)
is_concrete = mapper_args.get('concrete', is_concrete)

names_to_ignore = set(dct.keys())
names_to_ignore.add('query')

for base in cls.__bases__:
if (not is_concrete) and (hasattr(base, '__tablename__') or
hasattr(base, '__table__')):
return False

for k in base.mro():
for name, value in iteritems(k.__dict__):
if not (name in names_to_ignore or
(name.startswith('__') and name.endswith('__'))):
if is_primary_key_column(value):
return True
else:
names_to_ignore.add(name)

return False
59 changes: 57 additions & 2 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@

from sqlalchemy import orm, Column, types, inspect, Index
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy import orm, Column, types, inspect, Index, ForeignKey
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import relationship
from sqlalchemy.orm.exc import UnmappedClassError

from alchy import query

Expand Down Expand Up @@ -560,3 +562,56 @@ def test_should_set_tablename(self):
self.assertEqual(getattr(KKK, '__tablename__'), 'kkk')
self.assertEqual(hasattr(LLL, '__tablename__'), False)
self.assertEqual(getattr(MMM, '__tablename__'), 'mmm')

def test_should_set_tablename_declared_attr(self):
class IntColumnMixin(object):
@declared_attr
def mixed_in_column(cls):
return Column(types.Integer(),
primary_key=cls.__name__.endswith('Primary'))

# MixedInTablePrimary.mixed_in_column.primary_key is True,
# and so __tablename__ should be generated
class MixedInTablePrimary(IntColumnMixin, Model):
pass

self.assertEqual(getattr(MixedInTablePrimary, '__tablename__'),
'mixed_in_table_primary')

def get_MixedInTableNoPrimaryKey():
# MixedInTableNoPrimaryKey.mixed_in_column.primary_key is False,
# so no __tablename__ should be generated
class MixedInTableNoPrimaryKey(IntColumnMixin, Model):
pass

return MixedInTableNoPrimaryKey

self.assertRaises(InvalidRequestError, get_MixedInTableNoPrimaryKey)

# test that can handle mixed-in @declared_attr relationships
class RelationshipMixin(object):
@declared_attr
def foreign_id(cls):
return Column(types.Integer(),
ForeignKey(MixedInTablePrimary.mixed_in_column),
primary_key=cls.__name__.endswith('Primary'))

@declared_attr
def foreign(cls):
return relationship(MixedInTablePrimary,
foreign_keys=[cls.foreign_id])

class MixedInRelationshipPrimary(RelationshipMixin, Model):
pass

self.assertEqual(getattr(MixedInRelationshipPrimary, '__tablename__'),
'mixed_in_relationship_primary')

class MixedInRelAndDeclaredAttrPrimKey(RelationshipMixin, Model):
@declared_attr
def local_primary_key(cls):
return Column(types.Integer(), primary_key=True)

self.assertEqual(getattr(MixedInRelAndDeclaredAttrPrimKey,
'__tablename__'),
'mixed_in_rel_and_declared_attr_prim_key')