Skip to content

Commit

Permalink
Updated __tablename__-setting logic.
Browse files Browse the repository at this point in the history
  • Loading branch information
seth-p committed Mar 30, 2015
1 parent e7ca049 commit f64376f
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 13 deletions.
13 changes: 4 additions & 9 deletions alchy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from . import query, events
from .utils import (
is_sequence,
has_primary_key,
camelcase_to_underscore,
get_mapper_class,
merge_declarative_args,
should_set_tablename,
)
from ._compat import iteritems

Expand All @@ -37,12 +37,8 @@ class ModelMeta(DeclarativeMeta):
or :attr:`ModelBase.__events__`.
"""
def __new__(mcs, name, bases, dct):
# Determine if primary key is defined for dct or any of its bases.
base_dcts = [dct] + [base.__dict__ for base in bases]

if (not dct.get('__tablename__') and
dct.get('__table__') is None and
any([has_primary_key(base) for base in base_dcts])):
# Determine if should set __tablename__.
if should_set_tablename(bases, dct):
# Set to underscore version of class name.
dct['__tablename__'] = camelcase_to_underscore(name)

Expand All @@ -55,6 +51,7 @@ def __new__(mcs, name, bases, dct):
dct['__events__'] = {}

if '__bind_key__' not in dct:
base_dcts = [dct] + [base.__dict__ for base in bases]
for base in base_dcts:
if '__bind_key__' in base:
dct['__bind_key__'] = base['__bind_key__']
Expand All @@ -71,8 +68,6 @@ def __init__(cls, name, bases, dct):

events.register(cls, dct)

base_dcts = [dct] + [base.__dict__ for base in bases]


class ModelBase(object):
"""Base class for creating a declarative base for models.
Expand Down
61 changes: 60 additions & 1 deletion alchy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from collections import Iterable

from sqlalchemy import Column
from sqlalchemy.ext.declarative import AbstractConcreteBase

from ._compat import string_types, iteritems, classmethod_func
from ._compat import string_types, iteritems, itervalues, classmethod_func


__all__ = [
Expand Down Expand Up @@ -127,3 +128,61 @@ def merge_declarative_args(cls, global_config_key, local_config_key):
args = unique(args)

return (args, kargs)


def should_set_tablename(bases, dct):
"""Check what values are set by a class and its bases to determine if a
tablename should be automatically generated.
The class and its bases are checked in order of precedence: the class
itself then each base in the order they were given at class definition.
Abstract classes do not generate a tablename, although they may have set
or inherited a tablename elsewhere.
If a class defines a tablename or table, a new one will not be generated.
Otherwise, if the class defines a primary key, a new name will be
generated.
This supports:
* Joined table inheritance without explicitly naming sub-models.
* Single table inheritance.
* Concrete table inheritance
* Inheriting from mixins or abstract models.
:param bases: base classes of new class
:param dct: new class dict
:return: True if tablename should be set
"""

if '__tablename__' in dct or '__table__' in dct or '__abstract__' in dct:
return False

if has_primary_key(dct):
return True

if '__mapper_args__' in dct:
is_concrete = dct.get('__mapper_args__', {}).get('concrete', False)
else:
is_concrete = dct.get('__global_mapper_args__', {}).get('concrete',
False)
is_concrete = dct.get('__local_mapper_args__', {}).get('concrete',
is_concrete)

for base in bases:
if base is AbstractConcreteBase:
return False

if (not is_concrete) and (hasattr(base, '__tablename__') or
hasattr(base, '__table__')):
return False

for name in dir(base):
if not (name in ('query') or
(name.startswith('__') and name.endswith('__'))):
attr = getattr(base, name)
if getattr(attr, 'primary_key', False):
return True

return False
81 changes: 78 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

import sqlalchemy
from sqlalchemy import orm, Column, types, inspect, Index
from sqlalchemy import orm, Column, types, inspect, Index, ForeignKey
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.ext.declarative import ConcreteBase, AbstractConcreteBase

from alchy import model, query, manager, events
from alchy import query

from .base import TestQueryBase
from . import fixtures
Expand Down Expand Up @@ -540,3 +540,78 @@ def test_is_modified(self):

record.refresh()
self.assertEqual(record.is_modified(), False)

def test_should_set_tablename(self):
class AAA(Model):
__abstract__ = True
idx = Column(types.Integer(), primary_key=True)

self.assertEqual(hasattr(AAA, '__tablename__'), False)

class BBB(AAA):
__abstract__ = True
b_int = Column(types.Integer())

self.assertEqual(hasattr(BBB, '__tablename__'), False)

class CCC(BBB):
c_int = Column(types.Integer())

self.assertEqual(getattr(CCC, '__tablename__'), 'ccc')

# Joined table inheritance
class DDD(CCC):
idx = Column(types.Integer(), ForeignKey(CCC.idx),
primary_key=True)
d_int = Column(types.Integer())

self.assertEqual(getattr(DDD, '__tablename__'), 'ddd')

# Single table inheritance
class EEE(BBB):
idx = Column(types.Integer(), primary_key=True)
e_str = Column(types.String())
__global_mapper_args__ = {'polymorphic_on': e_str}

self.assertEqual(getattr(EEE, '__tablename__'), 'eee')

class FFF(EEE):
f_int = Column(types.Integer())
__local_mapper_args__ = {'polymorphic_identity': 'eee_subtype_fff'}

self.assertEqual(getattr(FFF, '__tablename__'), 'eee')

# Concrete table inheritance
class GGG(CCC):
idx = Column(types.Integer(), primary_key=True)
g_int = Column(types.Integer())
__local_mapper_args__ = {'concrete': True}

self.assertEqual(getattr(GGG, '__tablename__'), 'ggg')

# Concrete table inheritance - using ConcreteBase
class HHH(ConcreteBase, BBB):
h_int = Column(types.Integer())
__local_mapper_args__ = {'polymorphic_on': h_int, 'concrete': True}

self.assertEqual(getattr(HHH, '__tablename__'), 'hhh')

class III(HHH):
idx = Column(types.Integer(), primary_key=True)
i_int = Column(types.Integer())
__mapper_args__ = {'polymorphic_identity': 2, 'concrete': True}

self.assertEqual(getattr(III, '__tablename__'), 'iii')

# Concrete table inheritance - using AbstractConcreteBase
class JJJ(AbstractConcreteBase, BBB):
pass

self.assertEqual(hasattr(JJJ, '__tablename__'), False)

class KKK(JJJ):
idx = Column(types.Integer(), primary_key=True)
k_int = Column(types.Integer())
__mapper_args__ = {'polymorphic_identity': 2, 'concrete': True}

self.assertEqual(getattr(KKK, '__tablename__'), 'kkk')

0 comments on commit f64376f

Please sign in to comment.