Skip to content
This repository
Browse code

Adding some db-specific functionality and getting tests passing under

postgres, woo!
  • Loading branch information...
commit ca7bd0b92e2d00b58295c6947e71d0d2248beb7c 1 parent e04ca26
Charles Leifer authored October 02, 2012

Showing 2 changed files with 88 additions and 15 deletions. Show diff stats Hide diff stats

  1. 22  peewee.py
  2. 81  tests.py
22  peewee.py
@@ -857,7 +857,7 @@ def field_sql(self, field):
857 857
             parts.append('REFERENCES %s (%s)' % ref_mc)
858 858
             parts.append('%(cascade)s%(extra)s')
859 859
         elif field.sequence:
860  
-            parts.append('DEFAULT NEXTVAL(%s)' % self.quote(field.sequence))
  860
+            parts.append("DEFAULT NEXTVAL('%s')" % self.quote(field.sequence))
861 861
         return ' '.join(p % attrs for p in parts)
862 862
 
863 863
     def create_table(self, model_class, safe=False):
@@ -1484,10 +1484,20 @@ def create_index(self, model_class, fields, unique=False):
1484 1484
     def create_foreign_key(self, model_class, field):
1485 1485
         return self.create_index(model_class, [field], field.unique)
1486 1486
 
  1487
+    def create_sequence(self, seq):
  1488
+        if self.sequences:
  1489
+            qc = self.get_compiler()
  1490
+            return self.execute_sql(qc.create_sequence(seq))
  1491
+
1487 1492
     def drop_table(self, model_class, fail_silently=False):
1488 1493
         qc = self.get_compiler()
1489 1494
         return self.execute_sql(qc.drop_table(model_class, fail_silently))
1490 1495
 
  1496
+    def drop_sequence(self, seq):
  1497
+        if self.sequences:
  1498
+            qc = self.get_compiler()
  1499
+            return self.execute_sql(qc.drop_sequence(seq))
  1500
+
1491 1501
     def transaction(self):
1492 1502
         return transaction(self)
1493 1503
 
@@ -1532,13 +1542,14 @@ def get_tables(self):
1532 1542
 class PostgresqlDatabase(Database):
1533 1543
     field_overrides = {
1534 1544
         'bigint': 'BIGINT',
1535  
-        'boolean': 'BOOLEAN',
  1545
+        'bool': 'BOOLEAN',
1536 1546
         'datetime': 'TIMESTAMP',
1537 1547
         'decimal': 'NUMERIC',
1538 1548
         'double': 'DOUBLE PRECISION',
1539 1549
         'primary_key': 'SERIAL',
1540 1550
     }
1541 1551
     for_update = True
  1552
+    interpolation = '%s'
1542 1553
     reserved_tables = ['user']
1543 1554
     sequences = True
1544 1555
 
@@ -1550,10 +1561,10 @@ def _connect(self, database, **kwargs):
1550 1561
     def last_insert_id(self, cursor, model):
1551 1562
         seq = model._meta.primary_key.sequence
1552 1563
         if seq:
1553  
-            cursor.execute_sql("SELECT CURRVAL('\"%s\"')" % (seq))
  1564
+            cursor.execute("SELECT CURRVAL('\"%s\"')" % (seq))
1554 1565
             return cursor.fetchone()[0]
1555 1566
         elif model._meta.auto_increment:
1556  
-            cursor.execute_sql("SELECT CURRVAL('\"%s_%s_seq\"')" % (
  1567
+            cursor.execute("SELECT CURRVAL('\"%s_%s_seq\"')" % (
1557 1568
                 model._meta.db_table, model._meta.primary_key.db_column))
1558 1569
             return cursor.fetchone()[0]
1559 1570
 
@@ -1601,6 +1612,7 @@ class MySQLDatabase(Database):
1601 1612
         'text': 'LONGTEXT',
1602 1613
     }
1603 1614
     for_update_support = True
  1615
+    interpolation = '%s'
1604 1616
     op_overrides = {OP_LIKE: 'LIKE BINARY'}
1605 1617
     quote_char = '`'
1606 1618
     subquery_delete_same_table = False
@@ -1785,7 +1797,7 @@ def __new__(cls, name, bases, attrs):
1785 1797
             primary_key.add_to_class(cls, 'id')
1786 1798
 
1787 1799
         cls._meta.primary_key = primary_key
1788  
-        cls._meta.auto_increment = isinstance(primary_key, PrimaryKeyField)
  1800
+        cls._meta.auto_increment = isinstance(primary_key, PrimaryKeyField) or primary_key.sequence
1789 1801
         if not cls._meta.db_table:
1790 1802
             cls._meta.db_table = re.sub('[^\w]+', '_', cls.__name__.lower())
1791 1803
 
81  tests.py
@@ -159,9 +159,18 @@ class DBBlog(TestModel):
159 159
     title = CharField(db_column='db_title')
160 160
     user = ForeignKeyField(DBUser, db_column='db_user')
161 161
 
  162
+class SeqModelA(TestModel):
  163
+    id = IntegerField(primary_key=True, sequence='just_testing_seq')
  164
+    num = IntegerField()
  165
+
  166
+class SeqModelB(TestModel):
  167
+    id = IntegerField(primary_key=True, sequence='just_testing_seq')
  168
+    other_num = IntegerField()
  169
+
162 170
 
163 171
 MODELS = [User, Blog, Comment, Relationship, NullModel, UniqueModel, OrderedModel, Category, UserCategory,
164  
-          NonIntModel, NonIntRelModel, DBUser, DBBlog]
  172
+          NonIntModel, NonIntRelModel, DBUser, DBBlog, SeqModelA, SeqModelB]
  173
+INT = test_db.interpolation
165 174
 
166 175
 def drop_tables(only=None):
167 176
     for model in reversed(MODELS):
@@ -539,6 +548,9 @@ def setUp(self):
539 548
         drop_tables(self.requires)
540 549
         create_tables(self.requires)
541 550
 
  551
+    def tearDown(self):
  552
+        drop_tables(self.requires)
  553
+
542 554
     def create_user(self, username):
543 555
         return User.create(username=username)
544 556
 
@@ -697,20 +709,20 @@ def test_raw(self):
697 709
         self.create_users(3)
698 710
 
699 711
         qc = len(self.queries())
700  
-        rq = User.raw('select * from users where username IN (?,?)', 'u1', 'u3')
  712
+        rq = User.raw('select * from users where username IN (%s,%s)' % (INT,INT), 'u1', 'u3')
701 713
         self.assertEqual([u.username for u in rq], ['u1', 'u3'])
702 714
 
703 715
         # iterate again
704 716
         self.assertEqual([u.username for u in rq], ['u1', 'u3'])
705 717
         self.assertEqual(len(self.queries()) - qc, 1)
706 718
 
707  
-        rq = User.raw('select id, username, ? as secret from users where username = ?', 'sh', 'u2')
  719
+        rq = User.raw('select id, username, %s as secret from users where username = %s' % (INT,INT), 'sh', 'u2')
708 720
         self.assertEqual([u.secret for u in rq], ['sh'])
709 721
         self.assertEqual([u.username for u in rq], ['u2'])
710 722
 
711 723
 
712 724
 class ModelAPITestCase(ModelTestCase):
713  
-    requires = [User, Blog, Category]
  725
+    requires = [User, Blog, Category, UserCategory]
714 726
 
715 727
     def test_related_name(self):
716 728
         u1 = self.create_user('u1')
@@ -1378,10 +1390,59 @@ def test_connection_state(self):
1378 1390
         self.assertFalse(test_db.is_closed())
1379 1391
 
1380 1392
 
1381  
-class ForUpdateTestCase(ModelTestCase):
1382  
-    requires = []
1383  
-    # TODO
  1393
+if test_db.for_update:
  1394
+    class ForUpdateTestCase(ModelTestCase):
  1395
+        requires = [User]
1384 1396
 
1385  
-class SequenceTestCase(ModelTestCase):
1386  
-    requires = []
1387  
-    # TODO
  1397
+        def tearDown(self):
  1398
+            test_db.set_autocommit(True)
  1399
+
  1400
+        def test_for_update(self):
  1401
+            u1 = self.create_user('u1')
  1402
+            u2 = self.create_user('u2')
  1403
+            u3 = self.create_user('u3')
  1404
+
  1405
+            test_db.set_autocommit(False)
  1406
+
  1407
+            # select a user for update
  1408
+            users = User.select().where(User.username == 'u1').for_update()
  1409
+            updated = User.update(username='u1_edited').where(User.username == 'u1').execute()
  1410
+            self.assertEqual(updated, 1)
  1411
+
  1412
+            # open up a new connection to the database
  1413
+            new_db = database_class(database_name)
  1414
+
  1415
+            # select the username, it will not register as being updated
  1416
+            res = new_db.execute_sql('select username from users where id = %s;' % u1.id)
  1417
+            username = res.fetchone()[0]
  1418
+            self.assertEqual(username, 'u1')
  1419
+
  1420
+            # committing will cause the lock to be released
  1421
+            test_db.commit()
  1422
+
  1423
+            # now we get the update
  1424
+            res = new_db.execute_sql('select username from users where id = %s;' % u1.id)
  1425
+            username = res.fetchone()[0]
  1426
+            self.assertEqual(username, 'u1_edited')
  1427
+
  1428
+elif TEST_VERBOSITY > 0:
  1429
+    print 'Skipping "for update" tests'
  1430
+
  1431
+if test_db.sequences:
  1432
+    class SequenceTestCase(ModelTestCase):
  1433
+        requires = [SeqModelA, SeqModelB]
  1434
+
  1435
+        def test_sequence_shared(self):
  1436
+            a1 = SeqModelA.create(num=1)
  1437
+            a2 = SeqModelA.create(num=2)
  1438
+            b1 = SeqModelB.create(other_num=101)
  1439
+            b2 = SeqModelB.create(other_num=102)
  1440
+            a3 = SeqModelA.create(num=3)
  1441
+
  1442
+            self.assertEqual(a1.id, a2.id - 1)
  1443
+            self.assertEqual(a2.id, b1.id - 1)
  1444
+            self.assertEqual(b1.id, b2.id - 1)
  1445
+            self.assertEqual(b2.id, a3.id - 1)
  1446
+
  1447
+elif TEST_VERBOSITY > 0:
  1448
+    print 'Skipping "sequence" tests'

0 notes on commit ca7bd0b

Please sign in to comment.
Something went wrong with that request. Please try again.