Skip to content
This repository
Browse code

Merge branch 'feature/extras'

Conflicts:
	tests.py
  • Loading branch information...
commit 781bbc7d5a33277d6e965e58e6666139a1562b44 2 parents ffceea9 + 2145061
Charles Leifer authored June 08, 2012
0  extras/__init__.py
No changes.
71  extras/signals.py
... ...
@@ -0,0 +1,71 @@
  1
+from peewee import Model as _Model
  2
+
  3
+
  4
+class Signal(object):
  5
+    def __init__(self):
  6
+        self._flush()
  7
+
  8
+    def connect(self, receiver, name=None, sender=None):
  9
+        name = name or receiver.__name__
  10
+        if name not in self._receivers:
  11
+            self._receivers[name] = (receiver, sender)
  12
+            self._receiver_list.append(name)
  13
+        else:
  14
+            raise ValueError('receiver named %s already connected' % name)
  15
+
  16
+    def disconnect(self, receiver=None, name=None):
  17
+        if receiver:
  18
+            name = receiver.__name__
  19
+        if name:
  20
+            del self._receivers[name]
  21
+            self._receiver_list.remove(name)
  22
+        else:
  23
+            raise ValueError('a receiver or a name must be provided')
  24
+
  25
+    def send(self, instance, *args, **kwargs):
  26
+        sender = type(instance)
  27
+        responses = []
  28
+        for name in self._receiver_list:
  29
+            r, s = self._receivers[name]
  30
+            if s is None or sender is s:
  31
+                responses.append((r, r(sender, instance, *args, **kwargs)))
  32
+        return responses
  33
+
  34
+    def _flush(self):
  35
+        self._receivers = {}
  36
+        self._receiver_list = []
  37
+
  38
+
  39
+pre_save = Signal()
  40
+post_save = Signal()
  41
+pre_delete = Signal()
  42
+post_delete = Signal()
  43
+pre_init = Signal()
  44
+post_init = Signal()
  45
+
  46
+
  47
+class Model(_Model):
  48
+    def __init__(self, *args, **kwargs):
  49
+        super(Model, self).__init__(*args, **kwargs)
  50
+        pre_init.send(self)
  51
+
  52
+    def prepared(self):
  53
+        super(Model, self).prepared()
  54
+        post_init.send(self)
  55
+
  56
+    def save(self, *args, **kwargs):
  57
+        created = not bool(self.get_pk())
  58
+        pre_save.send(self, created=created)
  59
+        super(Model, self).save(*args, **kwargs)
  60
+        post_save.send(self, created=created)
  61
+
  62
+    def delete_instance(self, *args, **kwargs):
  63
+        pre_delete.send(self)
  64
+        super(Model, self).delete_instance(*args, **kwargs)
  65
+        post_delete.send(self)
  66
+
  67
+def connect(signal, name=None, sender=None):
  68
+    def decorator(fn):
  69
+        signal.connect(fn, name, sender)
  70
+        return fn
  71
+    return decorator
139  extras/tests.py
... ...
@@ -0,0 +1,139 @@
  1
+import unittest
  2
+
  3
+from peewee import *
  4
+import signals
  5
+
  6
+
  7
+db = SqliteDatabase(':memory:')
  8
+
  9
+class BaseSignalModel(signals.Model):
  10
+    class Meta:
  11
+        database = db
  12
+
  13
+class ModelA(BaseSignalModel):
  14
+    a = CharField()
  15
+
  16
+class ModelB(BaseSignalModel):
  17
+    b = CharField()
  18
+
  19
+
  20
+class SignalsTestCase(unittest.TestCase):
  21
+    def setUp(self):
  22
+        ModelA.create_table(True)
  23
+        ModelB.create_table(True)
  24
+
  25
+    def tearDown(self):
  26
+        ModelA.drop_table()
  27
+        ModelB.drop_table()
  28
+        signals.pre_save._flush()
  29
+        signals.post_save._flush()
  30
+        signals.pre_delete._flush()
  31
+        signals.post_delete._flush()
  32
+        signals.pre_init._flush()
  33
+        signals.post_init._flush()
  34
+
  35
+    def test_pre_save(self):
  36
+        state = []
  37
+
  38
+        @signals.connect(signals.pre_save)
  39
+        def pre_save(sender, instance, created):
  40
+            state.append((sender, instance, instance.get_pk(), created))
  41
+        m = ModelA()
  42
+        m.save()
  43
+        self.assertEqual(state, [(ModelA, m, None, True)])
  44
+
  45
+        m.save()
  46
+        self.assertTrue(m.id is not None)
  47
+        self.assertEqual(state[-1], (ModelA, m, m.id, False))
  48
+
  49
+    def test_post_save(self):
  50
+        state = []
  51
+
  52
+        @signals.connect(signals.post_save)
  53
+        def post_save(sender, instance, created):
  54
+            state.append((sender, instance, instance.get_pk(), created))
  55
+        m = ModelA()
  56
+        m.save()
  57
+
  58
+        self.assertTrue(m.id is not None)
  59
+        self.assertEqual(state, [(ModelA, m, m.id, True)])
  60
+
  61
+        m.save()
  62
+        self.assertEqual(state[-1], (ModelA, m, m.id, False))
  63
+
  64
+    def test_pre_delete(self):
  65
+        state = []
  66
+
  67
+        m = ModelA()
  68
+        m.save()
  69
+
  70
+        @signals.connect(signals.pre_delete)
  71
+        def pre_delete(sender, instance):
  72
+            state.append((sender, instance, ModelA.select().count()))
  73
+        m.delete_instance()
  74
+        self.assertEqual(state, [(ModelA, m, 1)])
  75
+
  76
+    def test_post_delete(self):
  77
+        state = []
  78
+
  79
+        m = ModelA()
  80
+        m.save()
  81
+
  82
+        @signals.connect(signals.post_delete)
  83
+        def post_delete(sender, instance):
  84
+            state.append((sender, instance, ModelA.select().count()))
  85
+        m.delete_instance()
  86
+        self.assertEqual(state, [(ModelA, m, 0)])
  87
+
  88
+    def test_pre_init(self):
  89
+        state = []
  90
+
  91
+        m = ModelA(a='a')
  92
+        m.save()
  93
+
  94
+        @signals.connect(signals.pre_init)
  95
+        def pre_init(sender, instance):
  96
+            state.append((sender, instance.a))
  97
+
  98
+        ModelA.get()
  99
+        self.assertEqual(state, [(ModelA, None)])
  100
+
  101
+    def test_post_init(self):
  102
+        state = []
  103
+
  104
+        m = ModelA(a='a')
  105
+        m.save()
  106
+
  107
+        @signals.connect(signals.post_init)
  108
+        def post_init(sender, instance):
  109
+            state.append((sender, instance.a))
  110
+
  111
+        ModelA.get()
  112
+        self.assertEqual(state, [(ModelA, 'a')])
  113
+
  114
+    def test_sender(self):
  115
+        state = []
  116
+
  117
+        @signals.connect(signals.post_save, sender=ModelA)
  118
+        def post_save(sender, instance, created):
  119
+            state.append(instance)
  120
+
  121
+        m = ModelA.create()
  122
+        self.assertEqual(state, [m])
  123
+
  124
+        m2 = ModelB.create()
  125
+        self.assertEqual(state, [m])
  126
+
  127
+    def test_connect_disconnect(self):
  128
+        state = []
  129
+
  130
+        @signals.connect(signals.post_save, sender=ModelA)
  131
+        def post_save(sender, instance, created):
  132
+            state.append(instance)
  133
+
  134
+        m = ModelA.create()
  135
+        self.assertEqual(state, [m])
  136
+
  137
+        signals.post_save.disconnect(post_save)
  138
+        m2 = ModelA.create()
  139
+        self.assertEqual(state, [m])
48  runtests.py
... ...
@@ -1,16 +1,50 @@
1 1
 #!/usr/bin/env python
  2
+import optparse
2 3
 import os
  4
+import sys
3 5
 import unittest
4 6
 
5 7
 
6  
-def collect():
7  
-    start_dir = os.path.abspath(os.path.dirname(__file__))
8  
-    return unittest.defaultTestLoader.discover(start_dir)
  8
+def runtests(module, verbosity):
  9
+    suite = unittest.TestLoader().loadTestsFromModule(module)
  10
+    results = unittest.TextTestRunner(verbosity=verbosity).run(suite)
  11
+    return results.failures, results.errors
9 12
 
  13
+def get_option_parser():
  14
+    parser = optparse.OptionParser()
  15
+    parser.add_option('-e', '--engine', dest='engine', default='sqlite', help='Database engine to test, one of [sqlite3, postgres, mysql]')
  16
+    parser.add_option('-v', '--verbosity', dest='verbosity', default=1, type='int', help='Verbosity of output')
  17
+    parser.add_option('-a', '--all', dest='all', default=False, action='store_true', help='Run all tests, including extras')
  18
+    parser.add_option('-x', '--extra', dest='extra', default=False, action='store_true', help='Run only extras tests')
  19
+    return parser
10 20
 
11 21
 if __name__ == '__main__':
12  
-    backend = os.environ.get('PEEWEE_TEST_BACKEND') or 'sqlite'
13  
-    print 'RUNNING PEEWEE TESTS WITH [%s]' % backend
14  
-    print '=============================================='
15  
-    unittest.main(module='tests')
  22
+    parser = get_option_parser()
  23
+    options, args = parser.parse_args()
16 24
 
  25
+    os.environ['PEEWEE_TEST_BACKEND'] = options.engine
  26
+    os.environ['PEEWEE_TEST_VERBOSITY'] = str(options.verbosity)
  27
+
  28
+    import tests
  29
+    from extras import tests as extras_tests
  30
+
  31
+    if options.all:
  32
+        modules = [tests, extras_tests]
  33
+    elif options.extra:
  34
+        modules = [extras_tests]
  35
+    else:
  36
+        modules = [tests]
  37
+
  38
+    results = []
  39
+    any_failures = False
  40
+    any_errors = False
  41
+    for module in modules:
  42
+        failures, errors = runtests(module, options.verbosity)
  43
+        any_failures = any_failures or bool(failures)
  44
+        any_errors = any_errors or bool(errors)
  45
+
  46
+    if any_errors:
  47
+        sys.exit(2)
  48
+    elif any_failures:
  49
+        sys.exit(1)
  50
+    sys.exit(0)
1  setup.py
@@ -13,6 +13,7 @@
13 13
     author='Charles Leifer',
14 14
     author_email='coleifer@gmail.com',
15 15
     url='http://github.com/coleifer/peewee/',
  16
+    packages=['extras'],
16 17
     py_modules=['peewee', 'pwiz'],
17 18
     classifiers=[
18 19
         'Development Status :: 4 - Beta',
7  tests.py
@@ -25,6 +25,7 @@ def emit(self, record):
25 25
 
26 26
 
27 27
 BACKEND = os.environ.get('PEEWEE_TEST_BACKEND', 'sqlite')
  28
+TEST_VERBOSITY = int(os.environ.get('PEEWEE_TEST_VERBOSITY') or 1)
28 29
 
29 30
 if BACKEND == 'postgresql':
30 31
     database_class = PostgresqlDatabase
@@ -3567,7 +3568,8 @@ def test_for_update(self):
3567 3568
             self.assertEqual(blog_title, 'b1_edited')
3568 3569
 
3569 3570
 else:
3570  
-    print 'Skipping for update tests because backend does not support'
  3571
+    if TEST_VERBOSITY > 0:
  3572
+        print 'Skipping for update tests because backend does not support'
3571 3573
 
3572 3574
 if test_db.adapter.sequence_support:
3573 3575
     class SequenceTestCase(BaseModelTestCase):
@@ -3602,7 +3604,8 @@ def test_sequence_shared(self):
3602 3604
             self.assertEqual(b2.id, a3.id - 1)
3603 3605
 
3604 3606
 else:
3605  
-    print 'Skipping sequence tests because backend does not support'
  3607
+    if TEST_VERBOSITY > 0:
  3608
+        print 'Skipping sequence tests because backend does not support'
3606 3609
 
3607 3610
 
3608 3611
 class TopologicalSortTestCase(unittest.TestCase):

0 notes on commit 781bbc7

Please sign in to comment.
Something went wrong with that request. Please try again.