Skip to content
This repository
Browse code

Adding a concurrency test case

  • Loading branch information...
commit e04ca26cf4020c5b61202c4ebb78189825dd86e5 1 parent 3597298
Charles Leifer authored October 02, 2012

Showing 1 changed file with 44 additions and 2 deletions. Show diff stats Hide diff stats

  1. 46  tests.py
46  tests.py
@@ -4,6 +4,7 @@
4 4
 import datetime
5 5
 import logging
6 6
 import os
  7
+import Queue
7 8
 import unittest
8 9
 from decimal import Decimal
9 10
 
@@ -1291,8 +1292,49 @@ def do_will_succeed2():
1291 1292
 
1292 1293
 
1293 1294
 class ConcurrencyTestCase(ModelTestCase):
1294  
-    requires = []
1295  
-    # TODO
  1295
+    requires = [User]
  1296
+
  1297
+    def setUp(self):
  1298
+        self._orig_db = test_db
  1299
+        User._meta.database = database_class(database_name, threadlocals=True)
  1300
+        super(ConcurrencyTestCase, self).setUp()
  1301
+
  1302
+    def tearDown(self):
  1303
+        User._meta.database = self._orig_db
  1304
+        super(ConcurrencyTestCase, self).tearDown()
  1305
+
  1306
+    def test_multiple_writers(self):
  1307
+        def create_user_thread(low, hi):
  1308
+            for i in range(low, hi):
  1309
+                User.create(username='u%d' % i)
  1310
+            User._meta.database.close()
  1311
+
  1312
+        threads = []
  1313
+
  1314
+        for i in range(5):
  1315
+            threads.append(threading.Thread(target=create_user_thread, args=(i*10, i * 10 + 10)))
  1316
+
  1317
+        [t.start() for t in threads]
  1318
+        [t.join() for t in threads]
  1319
+
  1320
+        self.assertEqual(User.select().count(), 50)
  1321
+
  1322
+    def test_multiple_readers(self):
  1323
+        data_queue = Queue.Queue()
  1324
+
  1325
+        def reader_thread(q, num):
  1326
+            for i in range(num):
  1327
+                data_queue.put(User.select().count())
  1328
+
  1329
+        threads = []
  1330
+
  1331
+        for i in range(5):
  1332
+            threads.append(threading.Thread(target=reader_thread, args=(data_queue, 20)))
  1333
+
  1334
+        [t.start() for t in threads]
  1335
+        [t.join() for t in threads]
  1336
+
  1337
+        self.assertEqual(data_queue.qsize(), 100)
1296 1338
 
1297 1339
 
1298 1340
 class ModelInheritanceTestCase(BasePeeweeTestCase):

0 notes on commit e04ca26

Please sign in to comment.
Something went wrong with that request. Please try again.