Skip to content

Commit

Permalink
Add hybrid tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
coleifer committed Apr 23, 2017
1 parent 54a6725 commit 08790cb
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 9 deletions.
10 changes: 1 addition & 9 deletions tests/__init__.py
Expand Up @@ -17,18 +17,10 @@
from .sql import *
from .transactions import *

from .hybrid import *
from .shortcuts import *
#from playhouse.tests.test_apis import *
#from playhouse.tests.test_compound_queries import *
#from playhouse.tests.test_database import *
#from playhouse.tests.test_fields import *
#from playhouse.tests.test_helpers import *
#from playhouse.tests.test_introspection import *
#from playhouse.tests.test_keys import *
#from playhouse.tests.test_models import *
#from playhouse.tests.test_queries import *
#from playhouse.tests.test_query_results import *
#from playhouse.tests.test_transactions import *


if __name__ == '__main__':
Expand Down
101 changes: 101 additions & 0 deletions tests/hybrid.py
@@ -0,0 +1,101 @@
from peewee import *
from playhouse.hybrid import *

from .base import ModelTestCase
from .base import TestModel
from .base import get_in_memory_db


class Interval(TestModel):
start = IntegerField()
end = IntegerField()

@hybrid_property
def length(self):
return self.end - self.start

@hybrid_method
def contains(self, point):
return (self.start <= point) & (point < self.end)

@hybrid_property
def radius(self):
return int(abs(self.length) / 2)

@radius.expression
def radius(cls):
return fn.ABS(cls.length) / 2


class Person(TestModel):
first = TextField()
last = TextField()

@hybrid_property
def full_name(self):
return self.first + ' ' + self.last


class TestHybridProperties(ModelTestCase):
database = get_in_memory_db()
requires = [Interval, Person]

def setUp(self):
super(TestHybridProperties, self).setUp()
intervals = (
(1, 5),
(2, 6),
(3, 5),
(2, 5))
for start, end in intervals:
Interval.create(start=start, end=end)

def test_hybrid_property(self):
query = Interval.select().where(Interval.length == 4)
self.assertSQL(query, (
'SELECT "t1"."id", "t1"."start", "t1"."end" '
'FROM "interval" AS "t1" '
'WHERE (("t1"."end" - "t1"."start") = ?)'), [4])

results = sorted((i.start, i.end) for i in query)
self.assertEqual(results, [(1, 5), (2, 6)])

query = Interval.select().order_by(Interval.id)
self.assertEqual([i.length for i in query], [4, 4, 2, 3])

def test_hybrid_method(self):
query = Interval.select().where(Interval.contains(2))
self.assertSQL(query, (
'SELECT "t1"."id", "t1"."start", "t1"."end" '
'FROM "interval" AS "t1" '
'WHERE (("t1"."start" <= ?) AND ("t1"."end" > ?))'), [2, 2])

results = sorted((i.start, i.end) for i in query)
self.assertEqual(results, [(1, 5), (2, 5), (2, 6)])

query = Interval.select().order_by(Interval.id)
self.assertEqual([i.contains(2) for i in query], [1, 1, 0, 1])

def test_expression(self):
query = Interval.select().where(Interval.radius == 2)
self.assertSQL(query, (
'SELECT "t1"."id", "t1"."start", "t1"."end" '
'FROM "interval" AS "t1" '
'WHERE ((ABS("t1"."end" - "t1"."start") / ?) = ?)'), [2, 2])

self.assertEqual(sorted((i.start, i.end) for i in query),
[(1, 5), (2, 6)])

query = Interval.select().order_by(Interval.id)
self.assertEqual([i.radius for i in query], [2, 2, 1, 1])

def test_string_fields(self):
huey = Person.create(first='huey', last='cat')
zaizee = Person.create(first='zaizee', last='kitten')

self.assertEqual(huey.full_name, 'huey cat')
self.assertEqual(zaizee.full_name, 'zaizee kitten')

query = Person.select().where(Person.full_name.startswith('huey c'))
huey_db = query.get()
self.assertEqual(huey_db.id, huey.id)

0 comments on commit 08790cb

Please sign in to comment.