From a62adf06b7454efed65aac4f0f2b2241aafe465c Mon Sep 17 00:00:00 2001 From: Seth Padowitz Date: Sun, 29 Mar 2015 16:53:29 -0400 Subject: [PATCH] Updated __tablename__-setting logic. --- alchy/model.py | 13 +++---- alchy/utils.py | 61 ++++++++++++++++++++++++++++++- tests/test_model.py | 87 +++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 148 insertions(+), 13 deletions(-) diff --git a/alchy/model.py b/alchy/model.py index c0ea185..a7bfc9a 100644 --- a/alchy/model.py +++ b/alchy/model.py @@ -11,10 +11,10 @@ from . import query, events from .utils import ( is_sequence, - has_primary_key, camelcase_to_underscore, get_mapper_class, merge_declarative_args, + should_set_tablename, ) from ._compat import iteritems @@ -37,12 +37,8 @@ class ModelMeta(DeclarativeMeta): or :attr:`ModelBase.__events__`. """ def __new__(mcs, name, bases, dct): - # Determine if primary key is defined for dct or any of its bases. - base_dcts = [dct] + [base.__dict__ for base in bases] - - if (not dct.get('__tablename__') and - dct.get('__table__') is None and - any([has_primary_key(base) for base in base_dcts])): + # Determine if should set __tablename__. + if should_set_tablename(bases, dct): # Set to underscore version of class name. dct['__tablename__'] = camelcase_to_underscore(name) @@ -55,6 +51,7 @@ def __new__(mcs, name, bases, dct): dct['__events__'] = {} if '__bind_key__' not in dct: + base_dcts = [dct] + [base.__dict__ for base in bases] for base in base_dcts: if '__bind_key__' in base: dct['__bind_key__'] = base['__bind_key__'] @@ -71,8 +68,6 @@ def __init__(cls, name, bases, dct): events.register(cls, dct) - base_dcts = [dct] + [base.__dict__ for base in bases] - class ModelBase(object): """Base class for creating a declarative base for models. diff --git a/alchy/utils.py b/alchy/utils.py index c6ee346..aa27001 100644 --- a/alchy/utils.py +++ b/alchy/utils.py @@ -8,8 +8,9 @@ from collections import Iterable from sqlalchemy import Column +from sqlalchemy.ext.declarative import AbstractConcreteBase -from ._compat import string_types, iteritems, classmethod_func +from ._compat import string_types, iteritems, itervalues, classmethod_func __all__ = [ @@ -127,3 +128,61 @@ def merge_declarative_args(cls, global_config_key, local_config_key): args = unique(args) return (args, kargs) + + +def should_set_tablename(bases, dct): + """Check what values are set by a class and its bases to determine if a + tablename should be automatically generated. + + The class and its bases are checked in order of precedence: the class + itself then each base in the order they were given at class definition. + + Abstract classes do not generate a tablename, although they may have set + or inherited a tablename elsewhere. + + If a class defines a tablename or table, a new one will not be generated. + Otherwise, if the class defines a primary key, a new name will be + generated. + + This supports: + + * Joined table inheritance without explicitly naming sub-models. + * Single table inheritance. + * Concrete table inheritance + * Inheriting from mixins or abstract models. + + :param bases: base classes of new class + :param dct: new class dict + :return: True if tablename should be set + """ + + if '__tablename__' in dct or '__table__' in dct or '__abstract__' in dct: + return False + + if has_primary_key(dct): + return True + + if '__mapper_args__' in dct: + is_concrete = dct.get('__mapper_args__', {}).get('concrete', False) + else: + is_concrete = dct.get('__global_mapper_args__', {}).get('concrete', + False) + is_concrete = dct.get('__local_mapper_args__', {}).get('concrete', + is_concrete) + + for base in bases: + if base is AbstractConcreteBase: + return False + + if (not is_concrete) and (hasattr(base, '__tablename__') or + hasattr(base, '__table__')): + return False + + for name in dir(base): + if not (name in ('query') or + (name.startswith('__') and name.endswith('__'))): + attr = getattr(base, name) + if getattr(attr, 'primary_key', False): + return True + + return False diff --git a/tests/test_model.py b/tests/test_model.py index 337818f..daba6b1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1,9 +1,9 @@ -import sqlalchemy -from sqlalchemy import orm, Column, types, inspect, Index +from sqlalchemy import orm, Column, types, inspect, Index, ForeignKey from sqlalchemy.orm.exc import UnmappedClassError +from sqlalchemy.ext.declarative import ConcreteBase, AbstractConcreteBase -from alchy import model, query, manager, events +from alchy import query from .base import TestQueryBase from . import fixtures @@ -540,3 +540,84 @@ def test_is_modified(self): record.refresh() self.assertEqual(record.is_modified(), False) + + def test_should_set_tablename(self): + class AAA(Model): + __abstract__ = True + idx = Column(types.Integer(), primary_key=True) + + self.assertEqual(hasattr(AAA, '__tablename__'), False) + + class BBB(AAA): + __abstract__ = True + b_int = Column(types.Integer()) + + self.assertEqual(hasattr(BBB, '__tablename__'), False) + + class CCC(BBB): + c_int = Column(types.Integer()) + + self.assertEqual(getattr(CCC, '__tablename__'), 'ccc') + + # Joined table inheritance + class DDD(CCC): + idx = Column(types.Integer(), ForeignKey(CCC.idx), + primary_key=True) + d_int = Column(types.Integer()) + + self.assertEqual(getattr(DDD, '__tablename__'), 'ddd') + + # Single table inheritance + class EEE(BBB): + idx = Column(types.Integer(), primary_key=True) + e_str = Column(types.String()) + __global_mapper_args__ = {'polymorphic_on': e_str} + + self.assertEqual(getattr(EEE, '__tablename__'), 'eee') + + class FFF(EEE): + f_int = Column(types.Integer()) + __local_mapper_args__ = {'polymorphic_identity': 'eee_subtype_fff'} + + self.assertEqual(getattr(FFF, '__tablename__'), 'eee') + + class FFF2(EEE): + f2_int = Column(types.Integer()) + __mapper_args__ = {'polymorphic_identity': 'eee_subtype_fff2'} + + self.assertEqual(getattr(FFF2, '__tablename__'), 'eee') + + # Concrete table inheritance + class GGG(CCC): + idx = Column(types.Integer(), primary_key=True) + g_int = Column(types.Integer()) + __local_mapper_args__ = {'concrete': True} + + self.assertEqual(getattr(GGG, '__tablename__'), 'ggg') + + # Concrete table inheritance - using ConcreteBase + class HHH(ConcreteBase, BBB): + h_int = Column(types.Integer()) + __local_mapper_args__ = {'polymorphic_on': h_int, 'concrete': True} + + self.assertEqual(getattr(HHH, '__tablename__'), 'hhh') + + class III(HHH): + idx = Column(types.Integer(), primary_key=True) + i_int = Column(types.Integer()) + __mapper_args__ = {'polymorphic_identity': 2, 'concrete': True} + + self.assertEqual(getattr(III, '__tablename__'), 'iii') + + # Concrete table inheritance - using AbstractConcreteBase + class JJJ(AbstractConcreteBase, BBB): + pass + + self.assertEqual(hasattr(JJJ, '__tablename__'), False) + + class KKK(JJJ): + idx = Column(types.Integer(), primary_key=True) + k_int = Column(types.Integer()) + __mapper_args__ = {'polymorphic_identity': 2, 'concrete': True} + + self.assertEqual(getattr(KKK, '__tablename__'), 'kkk')