Permalink
Browse files

Renamed Node.get_database_engine() to Node.get_database_vendor()

  • Loading branch information...
1 parent 1d90b8a commit b883bbb9271e6678afb62bf6424f2de0885968d5 @tabo tabo committed Dec 2, 2012
Showing with 25 additions and 25 deletions.
  1. +3 −0 CHANGES
  2. +16 −17 treebeard/models.py
  3. +6 −8 treebeard/mp_tree.py
View
@@ -28,6 +28,9 @@ Release 2.0 (XXX XX, 2012)
and/or signals.
* Improved translation files, including javascript.
* Fixed Django 1.4 support in the admin.
+* Renamed Node.get_database_engine() to Node.get_database_vendor(). As the name
+ implies, it returns the database vendor instead of the engine used. Treebeard
+ will get the value from Django, but you can subclass the method if needed.
Release 1.61 (Jul 24, 2010)
---------------------------
View
@@ -7,7 +7,7 @@
from functools import reduce
from django.db.models import Q
-from django.db import models, transaction
+from django.db import models, transaction, router, connections
from django.conf import settings
from treebeard.exceptions import InvalidPosition, MissingNodeOrderBy
@@ -16,6 +16,8 @@
class Node(models.Model):
"Node class"
+ _db_vendor = None
+
@classmethod
def add_root(cls, **kwargs): # pragma: no cover
"""
@@ -565,26 +567,23 @@ def _get_serializable_model(cls):
return cls
@classmethod
- def get_database_engine(cls):
+ def get_database_vendor(cls, action):
"""
- Returns the supported database engine used by a treebeard model.
+ Returns the supported database vendor used by a treebeard model when
+ performing read (select) or write (update, insert, delete) operations.
+
+ :param action:
- This will return the default database engine depending on the version
- of Django. If you use something different, like a non-default database,
- you need to override this method and return the correct engine.
+ `read` or `write`
- :returns: postgresql, postgresql_psycopg2, mysql or sqlite3
+ :returns: postgresql, mysql or sqlite
"""
- engine = None
- try:
- engine = settings.DATABASES['default']['ENGINE']
- except (AttributeError, KeyError):
- engine = None
- # the old style settings still work in Django 1.2+ if there is no
- # DATABASES setting
- if engine is None:
- engine = settings.DATABASE_ENGINE
- return engine.split('.')[-1]
+ if cls._db_vendor is None:
+ cls._db_vendor = {
+ 'read': connections[router.db_for_read(cls)].vendor,
+ 'write': connections[router.db_for_write(cls)].vendor
+ }
+ return cls._db_vendor[action]
class Meta:
"Abstract model."
View
@@ -295,7 +295,7 @@ def fix_tree(cls, destructive=False):
# fix the numchild field
vals = ['_' * cls.steplen]
# the cake and sql portability are a lie
- if cls.get_database_engine() == 'mysql':
+ if cls.get_database_vendor('read') == 'mysql':
sql = "SELECT tbn1.path, tbn1.numchild, ("\
"SELECT COUNT(1) "\
"FROM %(table)s AS tbn2 "\
@@ -865,7 +865,7 @@ def _updates_after_move(cls, oldpath, newpath, stmts):
2. update the number of children of parent nodes
"""
if (
- cls.get_database_engine() == 'mysql' and
+ cls.get_database_vendor('write') == 'mysql' and
len(oldpath) != len(newpath)
):
# no words can describe how dumb mysql is
@@ -896,18 +896,19 @@ def _get_sql_newpath_in_branches(cls, oldpath, newpath):
"""
+ vendor = cls.get_database_vendor('write')
sql1 = "UPDATE %s SET" % (
connection.ops.quote_name(cls._meta.db_table), )
# <3 "standard" sql
- if cls.get_database_engine() == 'sqlite3':
+ if vendor == 'sqlite':
# I know that the third argument in SUBSTR (LENGTH(path)) is
# awful, but sqlite fails without it:
# OperationalError: wrong number of arguments to function substr()
# even when the documentation says that 2 arguments are valid:
# http://www.sqlite.org/lang_corefunc.html
sqlpath = "%s||SUBSTR(path, %s, LENGTH(path))"
- elif cls.get_database_engine() == 'mysql':
+ elif vendor == 'mysql':
# hooray for mysql ignoring standards in their default
# configuration!
# to make || work as it should, enable ansi mode
@@ -918,10 +919,7 @@ def _get_sql_newpath_in_branches(cls, oldpath, newpath):
sql2 = ["path=%s" % (sqlpath, )]
vals = [newpath, len(oldpath) + 1]
- if (
- len(oldpath) != len(newpath) and
- cls.get_database_engine() != 'mysql'
- ):
+ if len(oldpath) != len(newpath) and vendor != 'mysql':
# when using mysql, this won't update the depth and it has to be
# done in another query
# doesn't even work with sql_mode='ANSI,TRADITIONAL'

0 comments on commit b883bbb

Please sign in to comment.