diff --git a/extras/__init__.py b/extras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/extras/signals.py b/extras/signals.py new file mode 100644 index 000000000..66824fb7e --- /dev/null +++ b/extras/signals.py @@ -0,0 +1,71 @@ +from peewee import Model as _Model + + +class Signal(object): + def __init__(self): + self._flush() + + def connect(self, receiver, name=None, sender=None): + name = name or receiver.__name__ + if name not in self._receivers: + self._receivers[name] = (receiver, sender) + self._receiver_list.append(name) + else: + raise ValueError('receiver named %s already connected' % name) + + def disconnect(self, receiver=None, name=None): + if receiver: + name = receiver.__name__ + if name: + del self._receivers[name] + self._receiver_list.remove(name) + else: + raise ValueError('a receiver or a name must be provided') + + def send(self, instance, *args, **kwargs): + sender = type(instance) + responses = [] + for name in self._receiver_list: + r, s = self._receivers[name] + if s is None or sender is s: + responses.append((r, r(sender, instance, *args, **kwargs))) + return responses + + def _flush(self): + self._receivers = {} + self._receiver_list = [] + + +pre_save = Signal() +post_save = Signal() +pre_delete = Signal() +post_delete = Signal() +pre_init = Signal() +post_init = Signal() + + +class Model(_Model): + def __init__(self, *args, **kwargs): + super(Model, self).__init__(*args, **kwargs) + pre_init.send(self) + + def prepared(self): + super(Model, self).prepared() + post_init.send(self) + + def save(self, *args, **kwargs): + created = not bool(self.get_pk()) + pre_save.send(self, created=created) + super(Model, self).save(*args, **kwargs) + post_save.send(self, created=created) + + def delete_instance(self, *args, **kwargs): + pre_delete.send(self) + super(Model, self).delete_instance(*args, **kwargs) + post_delete.send(self) + +def connect(signal, name=None, sender=None): + def decorator(fn): + signal.connect(fn, name, sender) + return fn + return decorator diff --git a/extras/tests.py b/extras/tests.py new file mode 100644 index 000000000..f10d840fe --- /dev/null +++ b/extras/tests.py @@ -0,0 +1,139 @@ +import unittest + +from peewee import * +import signals + + +db = SqliteDatabase(':memory:') + +class BaseSignalModel(signals.Model): + class Meta: + database = db + +class ModelA(BaseSignalModel): + a = CharField() + +class ModelB(BaseSignalModel): + b = CharField() + + +class SignalsTestCase(unittest.TestCase): + def setUp(self): + ModelA.create_table(True) + ModelB.create_table(True) + + def tearDown(self): + ModelA.drop_table() + ModelB.drop_table() + signals.pre_save._flush() + signals.post_save._flush() + signals.pre_delete._flush() + signals.post_delete._flush() + signals.pre_init._flush() + signals.post_init._flush() + + def test_pre_save(self): + state = [] + + @signals.connect(signals.pre_save) + def pre_save(sender, instance, created): + state.append((sender, instance, instance.get_pk(), created)) + m = ModelA() + m.save() + self.assertEqual(state, [(ModelA, m, None, True)]) + + m.save() + self.assertTrue(m.id is not None) + self.assertEqual(state[-1], (ModelA, m, m.id, False)) + + def test_post_save(self): + state = [] + + @signals.connect(signals.post_save) + def post_save(sender, instance, created): + state.append((sender, instance, instance.get_pk(), created)) + m = ModelA() + m.save() + + self.assertTrue(m.id is not None) + self.assertEqual(state, [(ModelA, m, m.id, True)]) + + m.save() + self.assertEqual(state[-1], (ModelA, m, m.id, False)) + + def test_pre_delete(self): + state = [] + + m = ModelA() + m.save() + + @signals.connect(signals.pre_delete) + def pre_delete(sender, instance): + state.append((sender, instance, ModelA.select().count())) + m.delete_instance() + self.assertEqual(state, [(ModelA, m, 1)]) + + def test_post_delete(self): + state = [] + + m = ModelA() + m.save() + + @signals.connect(signals.post_delete) + def post_delete(sender, instance): + state.append((sender, instance, ModelA.select().count())) + m.delete_instance() + self.assertEqual(state, [(ModelA, m, 0)]) + + def test_pre_init(self): + state = [] + + m = ModelA(a='a') + m.save() + + @signals.connect(signals.pre_init) + def pre_init(sender, instance): + state.append((sender, instance.a)) + + ModelA.get() + self.assertEqual(state, [(ModelA, None)]) + + def test_post_init(self): + state = [] + + m = ModelA(a='a') + m.save() + + @signals.connect(signals.post_init) + def post_init(sender, instance): + state.append((sender, instance.a)) + + ModelA.get() + self.assertEqual(state, [(ModelA, 'a')]) + + def test_sender(self): + state = [] + + @signals.connect(signals.post_save, sender=ModelA) + def post_save(sender, instance, created): + state.append(instance) + + m = ModelA.create() + self.assertEqual(state, [m]) + + m2 = ModelB.create() + self.assertEqual(state, [m]) + + def test_connect_disconnect(self): + state = [] + + @signals.connect(signals.post_save, sender=ModelA) + def post_save(sender, instance, created): + state.append(instance) + + m = ModelA.create() + self.assertEqual(state, [m]) + + signals.post_save.disconnect(post_save) + m2 = ModelA.create() + self.assertEqual(state, [m]) diff --git a/runtests.py b/runtests.py index 36d01ac99..af4599225 100755 --- a/runtests.py +++ b/runtests.py @@ -1,16 +1,50 @@ #!/usr/bin/env python +import optparse import os +import sys import unittest -def collect(): - start_dir = os.path.abspath(os.path.dirname(__file__)) - return unittest.defaultTestLoader.discover(start_dir) +def runtests(module, verbosity): + suite = unittest.TestLoader().loadTestsFromModule(module) + results = unittest.TextTestRunner(verbosity=verbosity).run(suite) + return results.failures, results.errors +def get_option_parser(): + parser = optparse.OptionParser() + parser.add_option('-e', '--engine', dest='engine', default='sqlite', help='Database engine to test, one of [sqlite3, postgres, mysql]') + parser.add_option('-v', '--verbosity', dest='verbosity', default=1, type='int', help='Verbosity of output') + parser.add_option('-a', '--all', dest='all', default=False, action='store_true', help='Run all tests, including extras') + parser.add_option('-x', '--extra', dest='extra', default=False, action='store_true', help='Run only extras tests') + return parser if __name__ == '__main__': - backend = os.environ.get('PEEWEE_TEST_BACKEND') or 'sqlite' - print 'RUNNING PEEWEE TESTS WITH [%s]' % backend - print '==============================================' - unittest.main(module='tests') + parser = get_option_parser() + options, args = parser.parse_args() + os.environ['PEEWEE_TEST_BACKEND'] = options.engine + os.environ['PEEWEE_TEST_VERBOSITY'] = str(options.verbosity) + + import tests + from extras import tests as extras_tests + + if options.all: + modules = [tests, extras_tests] + elif options.extra: + modules = [extras_tests] + else: + modules = [tests] + + results = [] + any_failures = False + any_errors = False + for module in modules: + failures, errors = runtests(module, options.verbosity) + any_failures = any_failures or bool(failures) + any_errors = any_errors or bool(errors) + + if any_errors: + sys.exit(2) + elif any_failures: + sys.exit(1) + sys.exit(0) diff --git a/setup.py b/setup.py index 1cc01bd12..2c852d04f 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ author='Charles Leifer', author_email='coleifer@gmail.com', url='http://github.com/coleifer/peewee/', + packages=['extras'], py_modules=['peewee', 'pwiz'], classifiers=[ 'Development Status :: 4 - Beta', diff --git a/tests.py b/tests.py index c2ef1966d..4e5d88a17 100644 --- a/tests.py +++ b/tests.py @@ -25,6 +25,7 @@ def emit(self, record): BACKEND = os.environ.get('PEEWEE_TEST_BACKEND', 'sqlite') +TEST_VERBOSITY = int(os.environ.get('PEEWEE_TEST_VERBOSITY') or 1) if BACKEND == 'postgresql': database_class = PostgresqlDatabase @@ -3567,7 +3568,8 @@ def test_for_update(self): self.assertEqual(blog_title, 'b1_edited') else: - print 'Skipping for update tests because backend does not support' + if TEST_VERBOSITY > 0: + print 'Skipping for update tests because backend does not support' if test_db.adapter.sequence_support: class SequenceTestCase(BaseModelTestCase): @@ -3602,7 +3604,8 @@ def test_sequence_shared(self): self.assertEqual(b2.id, a3.id - 1) else: - print 'Skipping sequence tests because backend does not support' + if TEST_VERBOSITY > 0: + print 'Skipping sequence tests because backend does not support' class TopologicalSortTestCase(unittest.TestCase):