Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Adding some db-specific functionality and getting tests passing under

postgres, woo!
  • Loading branch information...
commit ca7bd0b92e2d00b58295c6947e71d0d2248beb7c 1 parent e04ca26
@coleifer authored
Showing with 88 additions and 15 deletions.
  1. +17 −5 peewee.py
  2. +71 −10 tests.py
View
22 peewee.py
@@ -857,7 +857,7 @@ def field_sql(self, field):
parts.append('REFERENCES %s (%s)' % ref_mc)
parts.append('%(cascade)s%(extra)s')
elif field.sequence:
- parts.append('DEFAULT NEXTVAL(%s)' % self.quote(field.sequence))
+ parts.append("DEFAULT NEXTVAL('%s')" % self.quote(field.sequence))
return ' '.join(p % attrs for p in parts)
def create_table(self, model_class, safe=False):
@@ -1484,10 +1484,20 @@ def create_index(self, model_class, fields, unique=False):
def create_foreign_key(self, model_class, field):
return self.create_index(model_class, [field], field.unique)
+ def create_sequence(self, seq):
+ if self.sequences:
+ qc = self.get_compiler()
+ return self.execute_sql(qc.create_sequence(seq))
+
def drop_table(self, model_class, fail_silently=False):
qc = self.get_compiler()
return self.execute_sql(qc.drop_table(model_class, fail_silently))
+ def drop_sequence(self, seq):
+ if self.sequences:
+ qc = self.get_compiler()
+ return self.execute_sql(qc.drop_sequence(seq))
+
def transaction(self):
return transaction(self)
@@ -1532,13 +1542,14 @@ def get_tables(self):
class PostgresqlDatabase(Database):
field_overrides = {
'bigint': 'BIGINT',
- 'boolean': 'BOOLEAN',
+ 'bool': 'BOOLEAN',
'datetime': 'TIMESTAMP',
'decimal': 'NUMERIC',
'double': 'DOUBLE PRECISION',
'primary_key': 'SERIAL',
}
for_update = True
+ interpolation = '%s'
reserved_tables = ['user']
sequences = True
@@ -1550,10 +1561,10 @@ def _connect(self, database, **kwargs):
def last_insert_id(self, cursor, model):
seq = model._meta.primary_key.sequence
if seq:
- cursor.execute_sql("SELECT CURRVAL('\"%s\"')" % (seq))
+ cursor.execute("SELECT CURRVAL('\"%s\"')" % (seq))
return cursor.fetchone()[0]
elif model._meta.auto_increment:
- cursor.execute_sql("SELECT CURRVAL('\"%s_%s_seq\"')" % (
+ cursor.execute("SELECT CURRVAL('\"%s_%s_seq\"')" % (
model._meta.db_table, model._meta.primary_key.db_column))
return cursor.fetchone()[0]
@@ -1601,6 +1612,7 @@ class MySQLDatabase(Database):
'text': 'LONGTEXT',
}
for_update_support = True
+ interpolation = '%s'
op_overrides = {OP_LIKE: 'LIKE BINARY'}
quote_char = '`'
subquery_delete_same_table = False
@@ -1785,7 +1797,7 @@ def __new__(cls, name, bases, attrs):
primary_key.add_to_class(cls, 'id')
cls._meta.primary_key = primary_key
- cls._meta.auto_increment = isinstance(primary_key, PrimaryKeyField)
+ cls._meta.auto_increment = isinstance(primary_key, PrimaryKeyField) or primary_key.sequence
if not cls._meta.db_table:
cls._meta.db_table = re.sub('[^\w]+', '_', cls.__name__.lower())
View
81 tests.py
@@ -159,9 +159,18 @@ class DBBlog(TestModel):
title = CharField(db_column='db_title')
user = ForeignKeyField(DBUser, db_column='db_user')
+class SeqModelA(TestModel):
+ id = IntegerField(primary_key=True, sequence='just_testing_seq')
+ num = IntegerField()
+
+class SeqModelB(TestModel):
+ id = IntegerField(primary_key=True, sequence='just_testing_seq')
+ other_num = IntegerField()
+
MODELS = [User, Blog, Comment, Relationship, NullModel, UniqueModel, OrderedModel, Category, UserCategory,
- NonIntModel, NonIntRelModel, DBUser, DBBlog]
+ NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB]
+INT = test_db.interpolation
def drop_tables(only=None):
for model in reversed(MODELS):
@@ -539,6 +548,9 @@ def setUp(self):
drop_tables(self.requires)
create_tables(self.requires)
+ def tearDown(self):
+ drop_tables(self.requires)
+
def create_user(self, username):
return User.create(username=username)
@@ -697,20 +709,20 @@ def test_raw(self):
self.create_users(3)
qc = len(self.queries())
- rq = User.raw('select * from users where username IN (?,?)', 'u1', 'u3')
+ rq = User.raw('select * from users where username IN (%s,%s)' % (INT,INT), 'u1', 'u3')
self.assertEqual([u.username for u in rq], ['u1', 'u3'])
# iterate again
self.assertEqual([u.username for u in rq], ['u1', 'u3'])
self.assertEqual(len(self.queries()) - qc, 1)
- rq = User.raw('select id, username, ? as secret from users where username = ?', 'sh', 'u2')
+ rq = User.raw('select id, username, %s as secret from users where username = %s' % (INT,INT), 'sh', 'u2')
self.assertEqual([u.secret for u in rq], ['sh'])
self.assertEqual([u.username for u in rq], ['u2'])
class ModelAPITestCase(ModelTestCase):
- requires = [User, Blog, Category]
+ requires = [User, Blog, Category, UserCategory]
def test_related_name(self):
u1 = self.create_user('u1')
@@ -1378,10 +1390,59 @@ def test_connection_state(self):
self.assertFalse(test_db.is_closed())
-class ForUpdateTestCase(ModelTestCase):
- requires = []
- # TODO
+if test_db.for_update:
+ class ForUpdateTestCase(ModelTestCase):
+ requires = [User]
-class SequenceTestCase(ModelTestCase):
- requires = []
- # TODO
+ def tearDown(self):
+ test_db.set_autocommit(True)
+
+ def test_for_update(self):
+ u1 = self.create_user('u1')
+ u2 = self.create_user('u2')
+ u3 = self.create_user('u3')
+
+ test_db.set_autocommit(False)
+
+ # select a user for update
+ users = User.select().where(User.username == 'u1').for_update()
+ updated = User.update(username='u1_edited').where(User.username == 'u1').execute()
+ self.assertEqual(updated, 1)
+
+ # open up a new connection to the database
+ new_db = database_class(database_name)
+
+ # select the username, it will not register as being updated
+ res = new_db.execute_sql('select username from users where id = %s;' % u1.id)
+ username = res.fetchone()[0]
+ self.assertEqual(username, 'u1')
+
+ # committing will cause the lock to be released
+ test_db.commit()
+
+ # now we get the update
+ res = new_db.execute_sql('select username from users where id = %s;' % u1.id)
+ username = res.fetchone()[0]
+ self.assertEqual(username, 'u1_edited')
+
+elif TEST_VERBOSITY > 0:
+ print 'Skipping "for update" tests'
+
+if test_db.sequences:
+ class SequenceTestCase(ModelTestCase):
+ requires = [SeqModelA, SeqModelB]
+
+ def test_sequence_shared(self):
+ a1 = SeqModelA.create(num=1)
+ a2 = SeqModelA.create(num=2)
+ b1 = SeqModelB.create(other_num=101)
+ b2 = SeqModelB.create(other_num=102)
+ a3 = SeqModelA.create(num=3)
+
+ self.assertEqual(a1.id, a2.id - 1)
+ self.assertEqual(a2.id, b1.id - 1)
+ self.assertEqual(b1.id, b2.id - 1)
+ self.assertEqual(b2.id, a3.id - 1)
+
+elif TEST_VERBOSITY > 0:
+ print 'Skipping "sequence" tests'
Please sign in to comment.
Something went wrong with that request. Please try again.