Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewrite tablename generation again #541

Merged
merged 1 commit into from Sep 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
96 changes: 54 additions & 42 deletions flask_sqlalchemy/__init__.py
Expand Up @@ -30,7 +30,7 @@
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.orm.session import Session as SessionBase

from ._compat import iteritems, itervalues, string_types, xrange
from ._compat import itervalues, string_types, xrange

__version__ = '2.2.1'

Expand Down Expand Up @@ -551,31 +551,39 @@ def get_engine(self):


def _should_set_tablename(cls):
"""Traverse the model's MRO. If a primary key column is found before a
table or tablename, then a new tablename should be generated.

This supports:

* Joined table inheritance without explicitly naming sub-models.
* Single table inheritance.
* Inheriting from mixins or abstract models.

:param cls: model to check
:return: True if tablename should be set
"""Determine whether ``__tablename__`` should be automatically generated
for a model.

* If no class in the MRO sets a name, one should be generated.
* If a declared attr is found, it should be used instead.
* If a name is found, it should be used if the class is a mixin, otherwise
one should be generated.
* Abstract models should not have one generated.

Later, :meth:`._BoundDeclarativeMeta.__table_cls__` will determine if the
model looks like single or joined-table inheritance. If no primary key is
found, the name will be unset.
"""
if (
cls.__dict__.get('__abstract__', False)
or not any(isinstance(b, DeclarativeMeta) for b in cls.__mro__[1:])
):
return False

for base in cls.__mro__:
d = base.__dict__
if '__tablename__' not in base.__dict__:
continue

if '__tablename__' in d or '__table__' in d:
if isinstance(base.__dict__['__tablename__'], declared_attr):
return False

for name, obj in iteritems(d):
if isinstance(obj, declared_attr):
obj = getattr(cls, name)
return not (
base is cls
or base.__dict__.get('__abstract__', False)
or not isinstance(base, DeclarativeMeta)
)

if isinstance(obj, sqlalchemy.Column) and obj.primary_key:
return True
return True


def camel_to_snake_case(name):
Expand All @@ -591,20 +599,36 @@ def _join(match):


class _BoundDeclarativeMeta(DeclarativeMeta):
def __new__(cls, name, bases, d):
# if tablename is set explicitly, move it to the cache attribute so
# that future subclasses still have auto behavior
if '__tablename__' in d:
d['_cached_tablename'] = d.pop('__tablename__')
def __init__(cls, name, bases, d):
if _should_set_tablename(cls):
cls.__tablename__ = camel_to_snake_case(cls.__name__)

bind_key = (
d.pop('__bind_key__', None)
or getattr(cls, '__bind_key__', None)
)

return DeclarativeMeta.__new__(cls, name, bases, d)
super(_BoundDeclarativeMeta, cls).__init__(name, bases, d)

def __init__(self, name, bases, d):
bind_key = d.pop('__bind_key__', None) or getattr(self, '__bind_key__', None)
DeclarativeMeta.__init__(self, name, bases, d)
if bind_key is not None and hasattr(cls, '__table__'):
cls.__table__.info['bind_key'] = bind_key

if bind_key is not None and hasattr(self, '__table__'):
self.__table__.info['bind_key'] = bind_key
def __table_cls__(cls, *args, **kwargs):
"""This is called by SQLAlchemy during mapper setup. It determines the
final table object that the model will use.

If no primary key is found, that indicates single-table inheritance,
so no table will be created and ``__tablename__`` will be unset.
"""
for arg in args:
if (
(isinstance(arg, sqlalchemy.Column) and arg.primary_key)
or isinstance(arg, sqlalchemy.PrimaryKeyConstraint)
):
return sqlalchemy.Table(*args, **kwargs)

if '__tablename__' in cls.__dict__:
del cls.__tablename__


def get_state(app):
Expand Down Expand Up @@ -638,18 +662,6 @@ class Model(object):
#: Equivalent to ``db.session.query(Model)`` unless :attr:`query_class` has been changed.
query = None

_cached_tablename = None

@declared_attr
def __tablename__(cls):
if (
'_cached_tablename' not in cls.__dict__ and
_should_set_tablename(cls)
):
cls._cached_tablename = camel_to_snake_case(cls.__name__)

return cls._cached_tablename


class SQLAlchemy(object):
"""This class is used to control the SQLAlchemy integration to one
Expand Down
92 changes: 81 additions & 11 deletions tests/test_table_name.py
@@ -1,3 +1,5 @@
import inspect

from sqlalchemy.ext.declarative import declared_attr


Expand Down Expand Up @@ -25,6 +27,7 @@ class Duck(db.Model):
class Mallard(Duck):
pass

assert '__tablename__' not in Mallard.__dict__
assert Mallard.__tablename__ == 'duck'


Expand All @@ -39,8 +42,10 @@ class Donald(Duck):
assert Donald.__tablename__ == 'donald'


def test_mixin_name(db):
"""Primary key provided by mixin should still allow model to set tablename."""
def test_mixin_id(db):
"""Primary key provided by mixin should still allow model to set
tablename.
"""
class Base(object):
id = db.Column(db.Integer, primary_key=True)

Expand All @@ -51,28 +56,57 @@ class Duck(Base, db.Model):
assert Duck.__tablename__ == 'duck'


def test_mixin_attr(db):
"""A declared attr tablename will be used down multiple levels of
inheritance.
"""
class Mixin(object):
@declared_attr
def __tablename__(cls):
return cls.__name__.upper()

class Bird(Mixin, db.Model):
id = db.Column(db.Integer, primary_key=True)

class Duck(Bird):
# object reference
id = db.Column(db.ForeignKey(Bird.id), primary_key=True)

class Mallard(Duck):
# string reference
id = db.Column(db.ForeignKey('DUCK.id'), primary_key=True)

assert Bird.__tablename__ == 'BIRD'
assert Duck.__tablename__ == 'DUCK'
assert Mallard.__tablename__ == 'MALLARD'


def test_abstract_name(db):
"""Abstract model should not set a name. Subclass should set a name."""
"""Abstract model should not set a name. Subclass should set a name."""
class Base(db.Model):
__abstract__ = True
id = db.Column(db.Integer, primary_key=True)

class Duck(Base):
pass

assert Base.__tablename__ == 'base'
assert '__tablename__' not in Base.__dict__
assert Duck.__tablename__ == 'duck'


def test_complex_inheritance(db):
"""Joined table inheritance, but the new primary key is provided by a mixin, not directly on the class."""
"""Joined table inheritance, but the new primary key is provided by a
mixin, not directly on the class.
"""
class Duck(db.Model):
id = db.Column(db.Integer, primary_key=True)

class IdMixin(object):
@declared_attr
def id(cls):
return db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True)
return db.Column(
db.Integer, db.ForeignKey(Duck.id), primary_key=True
)

class RubberDuck(IdMixin, Duck):
pass
Expand All @@ -81,18 +115,55 @@ class RubberDuck(IdMixin, Duck):


def test_manual_name(db):
"""Setting a manual name prevents generation for the immediate model. A
name is generated for joined but not single-table inheritance.
"""
class Duck(db.Model):
__tablename__ = 'DUCK'
id = db.Column(db.Integer, primary_key=True)
type = db.Column(db.String)

__mapper_args__ = {
'polymorphic_on': type
}

class Daffy(Duck):
id = db.Column(db.Integer, db.ForeignKey(Duck.id), primary_key=True)

__mapper_args__ = {
'polymorphic_identity': 'Warner'
}

class Donald(Duck):
__mapper_args__ = {
'polymorphic_identity': 'Disney'
}

assert Duck.__tablename__ == 'DUCK'
assert Daffy.__tablename__ == 'daffy'
assert '__tablename__' not in Donald.__dict__
assert Donald.__tablename__ == 'DUCK'
# polymorphic condition for single-table query
assert 'WHERE "DUCK".type' in str(Donald.query)


def test_primary_constraint(db):
"""Primary key will be picked up from table args."""
class Duck(db.Model):
id = db.Column(db.Integer)

__table_args__ = (
db.PrimaryKeyConstraint(id),
)

assert Duck.__table__ is not None
assert Duck.__tablename__ == 'duck'


def test_no_access_to_class_property(db):
"""Ensure the implementation doesn't access class properties or declared
attrs while inspecting the unmapped model.
"""
class class_property(object):
def __init__(self, f):
self.f = f
Expand All @@ -106,14 +177,13 @@ class Duck(db.Model):
class ns(object):
accessed = False

# Since there's no id provided by the following model,
# _should_set_tablename will scan all attributes. If it's working
# properly, it won't access the class property, but will access the
# declared_attr.

class Witch(Duck):
@declared_attr
def is_duck(self):
# declared attrs will be accessed during mapper configuration,
# but make sure they're not accessed before that
info = inspect.getouterframes(inspect.currentframe())[2]
assert info[3] != '_should_set_tablename'
ns.accessed = True

@class_property
Expand Down