Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Fetching contributors…

Cannot retrieve contributors at this time

1812 lines (1395 sloc) 66.485 kb
# encoding=utf-8
from __future__ import with_statement
import datetime
import decimal
import logging
import os
import Queue
import threading
import unittest
from peewee import *
from peewee import QueryCompiler, R, SelectQuery, RawQuery, InsertQuery,\
UpdateQuery, DeleteQuery, logger, transaction, sort_models_topologically
class QueryLogHandler(logging.Handler):
    """Logging handler that keeps every record it receives in memory.

    The test cases attach one to peewee's logger so they can inspect and
    count the queries that were issued.
    """
    def __init__(self, *args, **kwargs):
        logging.Handler.__init__(self, *args, **kwargs)
        # Records are kept in arrival order and never discarded.
        self.queries = []

    def emit(self, record):
        captured = self.queries
        captured.append(record)
#
# JUNK TO ALLOW TESTING OF MULTIPLE DATABASE BACKENDS
#
# Backend under test, chosen via PEEWEE_TEST_BACKEND:
# 'postgresql', 'mysql', 'apsw'; any other value selects sqlite.
BACKEND = os.environ.get('PEEWEE_TEST_BACKEND', 'sqlite')
# Falls back to 1 when the variable is unset or empty.
TEST_VERBOSITY = int(os.environ.get('PEEWEE_TEST_VERBOSITY') or 1)

# Extra keyword arguments passed to the database constructor.
database_params = {}

if BACKEND == 'postgresql':
    database_class = PostgresqlDatabase
    database_name = 'peewee_test'
elif BACKEND == 'mysql':
    database_class = MySQLDatabase
    database_name = 'peewee_test'
elif BACKEND == 'apsw':
    from extras.apsw_ext import *
    database_class = APSWDatabase
    database_name = 'tmp.db'
    database_params['timeout'] = 1000
else:
    database_class = SqliteDatabase
    database_name = 'tmp.db'
    import sqlite3
    # sqlite3.version is the version of the Python DB-API module, not of the
    # SQLite library; sqlite_version is the library version this label names.
    # A single parenthesised argument prints the same under Python 2 and 3.
    print('SQLITE VERSION: %s' % sqlite3.sqlite_version)
#
# TEST-ONLY QUERY COMPILER USED TO CREATE "predictable" QUERIES
#
class TestQueryCompiler(QueryCompiler):
    """Compiler that aliases every table by its own table name.

    This makes the generated SQL predictable, so tests can compare it
    against literal strings.
    """
    def _max_alias(self, am):
        # Aliases are table names, so there is no numeric alias counter.
        return 0

    def calculate_alias_map(self, query, start=1):
        # `start` is accepted for signature compatibility and ignored.
        root = query.model_class
        alias_map = {root: root._meta.db_table}
        for src_model, join_list in query._joins.items():
            candidates = [src_model] + [j.model_class for j in join_list]
            for model_class in candidates:
                if model_class not in alias_map:
                    alias_map[model_class] = model_class._meta.db_table
        return alias_map
class TestDatabase(database_class):
    """Variant of the selected backend used only to compile SQL for comparisons.

    It has these fixed compilation settings:
    - TestQueryCompiler;
    - '?' placeholders;
    - double-quoted identifiers;
    - empty field/operator overrides.

    Presumably this is so the expected SQL strings in the tests do not
    depend on which backend is selected.
    """
    compiler_class = TestQueryCompiler
    interpolation = '?'
    quote_char = '"'
    field_overrides = {}
    op_overrides = {}
# Database the test models are bound to (backend selected above).
test_db = database_class(database_name, **database_params)
# Same connection parameters, but compiles with TestQueryCompiler; here it is
# used to build the module-level `compiler` shared by the SQL-string asserts.
query_db = TestDatabase(database_name, **database_params)
compiler = query_db.get_compiler()
#
# BASE MODEL CLASS
#
class TestModel(Model):
    """Base class for all test models; binds them to ``test_db``."""
    class Meta:
        database = test_db
#
# MODEL CLASSES USED BY TEST CASES
#
class User(TestModel):
    """User with a single username column.

    The table name is set explicitly to ``users`` via Meta.db_table.
    """
    username = CharField()
    class Meta:
        db_table = 'users'
class Blog(TestModel):
    """Blog entry belonging to a User.

    Its primary key field is named ``pk`` rather than ``id``.
    """
    user = ForeignKeyField(User)
    title = CharField(max_length=25)
    content = TextField(default='')
    pub_date = DateTimeField(null=True)
    # Explicit primary key, declared after the other fields.
    pk = PrimaryKeyField()
    def __unicode__(self):
        # Accessing self.user may issue a query if the related User is not cached.
        return '%s: %s' % (self.user.username, self.title)
class Comment(TestModel):
    """Comment on a Blog; reachable from a blog as ``blog.comments``."""
    blog = ForeignKeyField(Blog, related_name='comments')
    comment = CharField()
class Relationship(TestModel):
    """Directed user-to-user link: two foreign keys to the same model.

    Each key has its own related_name so both reverse accessors exist.
    """
    from_user = ForeignKeyField(User, related_name='relationships')
    to_user = ForeignKeyField(User, related_name='related_to')
class NullModel(TestModel):
    """One nullable column for each of the field types below."""
    char_field = CharField(null=True)
    text_field = TextField(null=True)
    datetime_field = DateTimeField(null=True)
    int_field = IntegerField(null=True)
    float_field = FloatField(null=True)
    decimal_field1 = DecimalField(null=True)
    # Same type as decimal_field1 but with an explicit scale of 2.
    decimal_field2 = DecimalField(decimal_places=2, null=True)
    double_field = DoubleField(null=True)
    bigint_field = BigIntegerField(null=True)
    date_field = DateField(null=True)
    time_field = TimeField(null=True)
    boolean_field = BooleanField(null=True)
class UniqueModel(TestModel):
    """Model with a single unique ``name`` column."""
    name = CharField(unique=True)
class OrderedModel(TestModel):
    """Model with a default ordering: newest ``created`` first."""
    title = CharField()
    # Callable default: evaluated at instance creation, not at import time.
    created = DateTimeField(default=datetime.datetime.now)
    class Meta:
        order_by = ('-created',)
class Category(TestModel):
    """Self-referential tree; root categories have a NULL ``parent``."""
    parent = ForeignKeyField('self', related_name='children', null=True)
    name = CharField()
class UserCategory(TestModel):
    """Link table between User and Category."""
    user = ForeignKeyField(User)
    category = ForeignKeyField(Category)
class NonIntModel(TestModel):
    """Model whose primary key is a CharField rather than an integer."""
    pk = CharField(primary_key=True)
    data = CharField()
class NonIntRelModel(TestModel):
    """Foreign key to a model with a non-integer primary key."""
    non_int_model = ForeignKeyField(NonIntModel, related_name='nr')
class DBUser(TestModel):
    """User-like model whose column names differ from its field names."""
    user_id = PrimaryKeyField(db_column='db_user_id')
    username = CharField(db_column='db_username')
class DBBlog(TestModel):
    """Blog-like model with custom column names, including on its foreign key."""
    blog_id = PrimaryKeyField(db_column='db_blog_id')
    title = CharField(db_column='db_title')
    user = ForeignKeyField(DBUser, db_column='db_user')
class SeqModelA(TestModel):
    """Integer primary key bound to the named sequence 'just_testing_seq'."""
    id = IntegerField(primary_key=True, sequence='just_testing_seq')
    num = IntegerField()
class SeqModelB(TestModel):
    """Second model using the same 'just_testing_seq' sequence as SeqModelA."""
    id = IntegerField(primary_key=True, sequence='just_testing_seq')
    other_num = IntegerField()
class MultiIndexModel(TestModel):
    """Model declaring two multi-column indexes through Meta.indexes."""
    f1 = CharField()
    f2 = CharField()
    f3 = CharField()
    class Meta:
        # Each entry is (field names, flag); the flag is presumably "unique".
        indexes = (
            (('f1', 'f2'), True),
            (('f2', 'f3'), False),
        )
class BlogTwo(Blog):
    """Subclass of Blog that redeclares ``title`` as a TextField and adds a field."""
    title = TextField()
    extra_field = CharField()
class Parent(TestModel):
    """Root of the model family used by RecursiveDeleteTestCase."""
    data = CharField()
class Child(TestModel):
    """Has a non-nullable foreign key to Parent."""
    parent = ForeignKeyField(Parent)
class Orphan(TestModel):
    """Has a nullable foreign key to Parent."""
    parent = ForeignKeyField(Parent, null=True)
class ChildPet(TestModel):
    """Second level below Parent via Child (non-nullable foreign key)."""
    child = ForeignKeyField(Child)
class OrphanPet(TestModel):
    """Second level below Parent via Orphan (non-nullable foreign key)."""
    orphan = ForeignKeyField(Orphan)
# Tables managed by create_tables()/drop_tables(), in dependency order:
# every model appears after the models it references. The
# Parent/Child/Orphan family is created separately by RecursiveDeleteTestCase.
MODELS = [
    User,
    Blog,
    Comment,
    Relationship,
    NullModel,
    UniqueModel,
    OrderedModel,
    Category,
    UserCategory,
    NonIntModel,
    NonIntRelModel,
    DBUser,
    DBBlog,
    SeqModelA,
    SeqModelB,
    MultiIndexModel,
    BlogTwo,
]

# Placeholder string of the real backend, for hand-written raw SQL.
INT = test_db.interpolation
def drop_tables(only=None):
    """Drop model tables in reverse MODELS order.

    Referencing tables are therefore dropped before the tables they point at.

    :param only: optional collection of models; when given, only models
        that appear in both MODELS and ``only`` are dropped.
    """
    selected = [m for m in reversed(MODELS) if only is None or m in only]
    for model in selected:
        # True is passed so a missing table is presumably ignored (fail silently).
        model.drop_table(True)
def create_tables(only=None):
    """Create model tables in MODELS order (referenced tables first).

    :param only: optional collection of models; when given, models not in
        it are skipped.
    """
    for model in MODELS:
        if only is not None and model not in only:
            continue
        model.create_table()
#
# BASE TEST CASE USED BY ALL TESTS
#
class BasePeeweeTestCase(unittest.TestCase):
    """Common base for the test cases in this module.

    setUp attaches a fresh QueryLogHandler to peewee's ``logger`` at DEBUG
    level, so tests can inspect the queries that were issued. tearDown
    detaches the handler and restores the logger's previous level.

    The ``assert*`` helpers compile a single clause of a query with the
    module-level test ``compiler``. They then compare the SQL fragment and
    the parameter list.
    """
    def setUp(self):
        self.qh = QueryLogHandler()
        # Remember the previous level so tearDown can undo the DEBUG override
        # instead of leaving it set on the shared logger for the whole run.
        self._orig_log_level = logger.level
        logger.setLevel(logging.DEBUG)
        logger.addHandler(self.qh)

    def tearDown(self):
        logger.removeHandler(self.qh)
        logger.setLevel(self._orig_log_level)

    def queries(self):
        """Return the message of every record logged since setUp, oldest first.

        Messages are whatever peewee passed to its logger; tests in this
        module unpack them as ``(sql, params)`` pairs.
        """
        return [x.msg for x in self.qh.queries]

    def parse_expr(self, query, expr_list):
        """Compile a list of expressions (select / group by / order by) to (sql, params)."""
        am = compiler.calculate_alias_map(query)
        return compiler.parse_expr_list(expr_list, am)

    def parse_node(self, query, node):
        """Compile a single query node (where / having) to (sql, params)."""
        am = compiler.calculate_alias_map(query)
        return compiler.parse_query_node(node, am)

    def make_fn(fn_name, attr_name):
        # Runs while the class body executes (hence no ``self``). It builds an
        # assertion method that:
        # - compiles ``getattr(query, attr_name)`` with the parser named
        #   ``fn_name``;
        # - compares the resulting SQL and params with the expected values.
        def inner(self, query, expected, expected_params):
            fn = getattr(self, fn_name)
            att = getattr(query, attr_name)
            sql, params = fn(query, att)
            self.assertEqual(sql, expected)
            self.assertEqual(params, expected_params)
        return inner

    assertSelect = make_fn('parse_expr', '_select')
    assertWhere = make_fn('parse_node', '_where')
    assertGroupBy = make_fn('parse_expr', '_group_by')
    assertHaving = make_fn('parse_node', '_having')
    assertOrderBy = make_fn('parse_expr', '_order_by')

    def assertJoins(self, sq, exp_joins):
        """Compare the compiled JOIN clauses of ``sq``, ignoring their order."""
        am = compiler.calculate_alias_map(sq)
        joins = compiler.parse_joins(sq._joins, sq.model_class, am)
        self.assertEqual(sorted(joins), sorted(exp_joins))

    def assertDict(self, qd, expected, expected_params):
        """Compare a compiled {field: value} mapping as (column, sql) pairs plus params."""
        sets, params = compiler._parse_field_dictionary(qd)
        self.assertEqual(sets, expected)
        self.assertEqual(params, expected_params)

    def assertUpdate(self, uq, expected, expected_params):
        """Check the SET mapping of an UpdateQuery."""
        self.assertDict(uq._update, expected, expected_params)

    def assertInsert(self, uq, expected, expected_params):
        """Check the column/value mapping of an InsertQuery."""
        self.assertDict(uq._insert, expected, expected_params)
#
# BASIC TESTS OF QUERY TYPES AND INTERNAL DATA STRUCTURES
#
class SelectTestCase(BasePeeweeTestCase):
    """Compile-only checks of the individual clauses of SELECT queries.

    No tables are created. Each test builds a query and compares the SQL
    fragment and params produced by the test compiler.
    """
    def test_selection(self):
        """SELECT list: default columns, explicit fields, functions, aliases, subqueries."""
        sq = SelectQuery(User)
        self.assertSelect(sq, 'users."id", users."username"', [])
        sq = SelectQuery(Blog, Blog.pk, Blog.title, Blog.user, User.username).join(User)
        self.assertSelect(sq, 'blog."pk", blog."title", blog."user_id", users."username"', [])
        sq = SelectQuery(User, fn.Lower(fn.Substr(User.username, 0, 1)).alias('lu'), fn.Count(Blog.pk)).join(Blog)
        self.assertSelect(sq, 'Lower(Substr(users."username", ?, ?)) AS lu, Count(blog."pk")', [0, 1])
        sq = SelectQuery(User, User.username, fn.Count(Blog.select().where(Blog.user == User.id)))
        self.assertSelect(sq, 'users."username", Count((SELECT blog."pk" FROM "blog" AS blog WHERE blog."user_id" = users."id"))', [])
    def test_joins(self):
        """Inner and left outer joins, with and without an explicit ``on`` foreign key."""
        sq = SelectQuery(User).join(Blog)
        self.assertJoins(sq, ['INNER JOIN "blog" AS blog ON users."id" = blog."user_id"'])
        sq = SelectQuery(Blog).join(User, JOIN_LEFT_OUTER)
        self.assertJoins(sq, ['LEFT OUTER JOIN "users" AS users ON blog."user_id" = users."id"'])
        sq = SelectQuery(User).join(Relationship)
        self.assertJoins(sq, ['INNER JOIN "relationship" AS relationship ON users."id" = relationship."from_user_id"'])
        sq = SelectQuery(User).join(Relationship, on=Relationship.to_user)
        self.assertJoins(sq, ['INNER JOIN "relationship" AS relationship ON users."id" = relationship."to_user_id"'])
        sq = SelectQuery(User).join(Relationship, JOIN_LEFT_OUTER, Relationship.to_user)
        self.assertJoins(sq, ['LEFT OUTER JOIN "relationship" AS relationship ON users."id" = relationship."to_user_id"'])
    def test_join_self_referential(self):
        """Joining a model to itself through its self-referential foreign key."""
        sq = SelectQuery(Category).join(Category)
        self.assertJoins(sq, ['INNER JOIN "category" AS category ON category."parent_id" = category."id"'])
    def test_join_both_sides(self):
        """switch(Blog) lets a second join start from Blog again, in either order."""
        sq = SelectQuery(Blog).join(Comment).switch(Blog).join(User)
        self.assertJoins(sq, [
            'INNER JOIN "comment" AS comment ON blog."pk" = comment."blog_id"',
            'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
        ])
        sq = SelectQuery(Blog).join(User).switch(Blog).join(Comment)
        self.assertJoins(sq, [
            'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
            'INNER JOIN "comment" AS comment ON blog."pk" = comment."blog_id"',
        ])
    def test_join_switching(self):
        """Several join chains hanging off one model; both orderings give the same joins."""
        class Artist(TestModel):
            pass
        class Track(TestModel):
            artist = ForeignKeyField(Artist)
        class Release(TestModel):
            artist = ForeignKeyField(Artist)
        class ReleaseTrack(TestModel):
            track = ForeignKeyField(Track)
            release = ForeignKeyField(Release)
        class Genre(TestModel):
            pass
        class TrackGenre(TestModel):
            genre = ForeignKeyField(Genre)
            track = ForeignKeyField(Track)
        multiple_first = Track.select().join(ReleaseTrack).join(Release).switch(Track).join(Artist).switch(Track).join(TrackGenre).join(Genre)
        self.assertSelect(multiple_first, 'track."id", track."artist_id"', [])
        # Expected lists are alphabetical; assertJoins compares them sorted.
        self.assertJoins(multiple_first, [
            'INNER JOIN "artist" AS artist ON track."artist_id" = artist."id"',
            'INNER JOIN "genre" AS genre ON trackgenre."genre_id" = genre."id"',
            'INNER JOIN "release" AS release ON releasetrack."release_id" = release."id"',
            'INNER JOIN "releasetrack" AS releasetrack ON track."id" = releasetrack."track_id"',
            'INNER JOIN "trackgenre" AS trackgenre ON track."id" = trackgenre."track_id"',
        ])
        single_first = Track.select().join(Artist).switch(Track).join(ReleaseTrack).join(Release).switch(Track).join(TrackGenre).join(Genre)
        self.assertSelect(single_first, 'track."id", track."artist_id"', [])
        self.assertJoins(single_first, [
            'INNER JOIN "artist" AS artist ON track."artist_id" = artist."id"',
            'INNER JOIN "genre" AS genre ON trackgenre."genre_id" = genre."id"',
            'INNER JOIN "release" AS release ON releasetrack."release_id" = release."id"',
            'INNER JOIN "releasetrack" AS releasetrack ON track."id" = releasetrack."track_id"',
            'INNER JOIN "trackgenre" AS trackgenre ON track."id" = trackgenre."track_id"',
        ])
    def test_where(self):
        """A simple comparison becomes a parameterised WHERE fragment."""
        sq = SelectQuery(User).where(User.id < 5)
        self.assertWhere(sq, 'users."id" < ?', [5])
    def test_where_lists(self):
        """``<<`` against a list compiles to IN with one placeholder per item."""
        sq = SelectQuery(User).where(User.username << ['u1', 'u2'])
        self.assertWhere(sq, 'users."username" IN (?,?)', ['u1', 'u2'])
        sq = SelectQuery(User).where((User.username << ['u1', 'u2']) | (User.username << ['u3', 'u4']))
        self.assertWhere(sq, '(users."username" IN (?,?) OR users."username" IN (?,?))', ['u1', 'u2', 'u3', 'u4'])
    def test_where_joins(self):
        """Conditions on joined models and chained where() calls are ANDed together."""
        sq = SelectQuery(User).where(
            ((User.id == 1) | (User.id == 2)) &
            ((Blog.pk == 3) | (Blog.pk == 4))
        ).where(User.id == 5).join(Blog)
        self.assertWhere(sq, '(users."id" = ? OR users."id" = ?) AND (blog."pk" = ? OR blog."pk" = ?) AND users."id" = ?', [1, 2, 3, 4, 5])
    def test_where_functions(self):
        """SQL functions can appear on the left-hand side of a comparison."""
        sq = SelectQuery(User).where(fn.Lower(fn.Substr(User.username, 0, 1)) == 'a')
        self.assertWhere(sq, 'Lower(Substr(users."username", ?, ?)) = ?', [0, 1, 'a'])
    def test_where_subqueries(self):
        """``<<`` against a SelectQuery compiles to IN (subquery), params inlined in order."""
        sq = SelectQuery(User).where(User.id << User.select().where(User.username=='u1'))
        self.assertWhere(sq, 'users."id" IN (SELECT users."id" FROM "users" AS users WHERE users."username" = ?)', ['u1'])
        sq = SelectQuery(Blog).where((Blog.pk == 3) | (Blog.user << User.select().where(User.username << ['u1', 'u2'])))
        self.assertWhere(sq, '(blog."pk" = ? OR blog."user_id" IN (SELECT users."id" FROM "users" AS users WHERE users."username" IN (?,?)))', [3, 'u1', 'u2'])
    def test_where_fk(self):
        """Model instances compared with a foreign key are passed as their id."""
        sq = SelectQuery(Blog).where(Blog.user == User(id=100))
        self.assertWhere(sq, 'blog."user_id" = ?', [100])
        sq = SelectQuery(Blog).where(Blog.user << [User(id=100), User(id=101)])
        self.assertWhere(sq, 'blog."user_id" IN (?,?)', [100, 101])
    def test_where_negation(self):
        """``~`` negates a single condition or a parenthesised group."""
        sq = SelectQuery(Blog).where(~(Blog.title == 'foo'))
        self.assertWhere(sq, 'NOT blog."title" = ?', ['foo'])
        sq = SelectQuery(Blog).where(~((Blog.title == 'foo') | (Blog.title == 'bar')))
        self.assertWhere(sq, '(NOT (blog."title" = ? OR blog."title" = ?))', ['foo', 'bar'])
        sq = SelectQuery(Blog).where(~((Blog.title == 'foo') & (Blog.title == 'bar')) & (Blog.title == 'baz'))
        self.assertWhere(sq, '(NOT (blog."title" = ? AND blog."title" = ?)) AND blog."title" = ?', ['foo', 'bar', 'baz'])
        sq = SelectQuery(Blog).where(~((Blog.title == 'foo') & (Blog.title == 'bar')) & ((Blog.title == 'baz') & (Blog.title == 'fizz')))
        self.assertWhere(sq, '(NOT (blog."title" = ? AND blog."title" = ?)) AND (blog."title" = ? AND blog."title" = ?)', ['foo', 'bar', 'baz', 'fizz'])
    def test_where_chaining_collapsing(self):
        """How consecutive where() calls are merged and parenthesised."""
        sq = SelectQuery(User).where(User.id == 1).where(User.id == 2).where(User.id == 3)
        self.assertWhere(sq, 'users."id" = ? AND users."id" = ? AND users."id" = ?', [1, 2, 3])
        sq = SelectQuery(User).where((User.id == 1) & (User.id == 2)).where(User.id == 3)
        self.assertWhere(sq, 'users."id" = ? AND users."id" = ? AND users."id" = ?', [1, 2, 3])
        sq = SelectQuery(User).where((User.id == 1) | (User.id == 2)).where(User.id == 3)
        self.assertWhere(sq, '(users."id" = ? OR users."id" = ?) AND users."id" = ?', [1, 2, 3])
        sq = SelectQuery(User).where(User.id == 1).where((User.id == 2) & (User.id == 3))
        self.assertWhere(sq, 'users."id" = ? AND users."id" = ? AND users."id" = ?', [1, 2, 3])
        # NOTE(review): the lone first condition is parenthesised here but not in
        # the cases above; confirm this is the intended compiler output.
        sq = SelectQuery(User).where(User.id == 1).where((User.id == 2) | (User.id == 3))
        self.assertWhere(sq, '(users."id" = ?) AND (users."id" = ? OR users."id" = ?)', [1, 2, 3])
        # NOTE(review): the expected SQL has no NOT for the first ~(User.id == 1)
        # condition; confirm this pins intended behaviour rather than a bug.
        sq = SelectQuery(User).where(~(User.id == 1)).where(User.id == 2).where(~(User.id == 3))
        self.assertWhere(sq, '(users."id" = ? AND users."id" = ?) AND NOT users."id" = ?', [1, 2, 3])
    def test_grouping(self):
        """group_by accepts a field or a whole model (all of its columns)."""
        sq = SelectQuery(User).group_by(User.id)
        self.assertGroupBy(sq, 'users."id"', [])
        sq = SelectQuery(User).group_by(User)
        self.assertGroupBy(sq, 'users."id", users."username"', [])
    def test_having(self):
        """HAVING conditions on aggregates, single and ORed."""
        sq = SelectQuery(User, fn.Count(Blog.pk)).join(Blog).group_by(User).having(
            fn.Count(Blog.pk) > 2
        )
        self.assertHaving(sq, 'Count(blog."pk") > ?', [2])
        sq = SelectQuery(User, fn.Count(Blog.pk)).join(Blog).group_by(User).having(
            (fn.Count(Blog.pk) > 10) | (fn.Count(Blog.pk) < 2)
        )
        self.assertHaving(sq, '(Count(blog."pk") > ? OR Count(blog."pk") < ?)', [10, 2])
    def test_ordering(self):
        """ORDER BY with fields, asc/desc, functions, raw SQL and Meta.order_by defaults."""
        sq = SelectQuery(User).join(Blog).order_by(Blog.title)
        self.assertOrderBy(sq, 'blog."title"', [])
        sq = SelectQuery(User).join(Blog).order_by(Blog.title.asc())
        self.assertOrderBy(sq, 'blog."title" ASC', [])
        sq = SelectQuery(User).join(Blog).order_by(Blog.title.desc())
        self.assertOrderBy(sq, 'blog."title" DESC', [])
        sq = SelectQuery(User).join(Blog).order_by(User.username.desc(), Blog.title.asc())
        self.assertOrderBy(sq, 'users."username" DESC, blog."title" ASC', [])
        base_sq = SelectQuery(User, User.username, fn.Count(Blog.pk).alias('count')).join(Blog).group_by(User.username)
        sq = base_sq.order_by(fn.Count(Blog.pk).desc())
        self.assertOrderBy(sq, 'Count(blog."pk") DESC', [])
        sq = base_sq.order_by(R('count'))
        self.assertOrderBy(sq, 'count', [])
        # OrderedModel declares Meta.order_by = ('-created',).
        sq = OrderedModel.select()
        self.assertOrderBy(sq, 'orderedmodel."created" DESC', [])
        sq = OrderedModel.select().order_by(OrderedModel.id.asc())
        self.assertOrderBy(sq, 'orderedmodel."id" ASC', [])
    def test_paginate(self):
        """paginate(page, n) sets limit n and offset (page - 1) * n."""
        sq = SelectQuery(User).paginate(1, 20)
        self.assertEqual(sq._limit, 20)
        self.assertEqual(sq._offset, 0)
        sq = SelectQuery(User).paginate(3, 30)
        self.assertEqual(sq._limit, 30)
        self.assertEqual(sq._offset, 60)
class UpdateTestCase(BasePeeweeTestCase):
    """Compile-only checks of UPDATE queries."""
    def test_update(self):
        """SET values: literals, model instances (sent as their id) and expressions."""
        uq = UpdateQuery(User, {User.username: 'updated'})
        self.assertUpdate(uq, [('"username"', '?')], ['updated'])
        uq = UpdateQuery(Blog, {Blog.user: User(id=100, username='foo')})
        self.assertUpdate(uq, [('"user_id"', '?')], [100])
        uq = UpdateQuery(User, {User.id: User.id + 5})
        self.assertUpdate(uq, [('"id"', '("id" + ?)')], [5])
        uq = UpdateQuery(User, {User.id: 5 * (3 + User.id)})
        self.assertUpdate(uq, [('"id"', '(? * (? + "id"))')], [5, 3])
    def test_where(self):
        """UPDATE ... WHERE compiles like a SELECT's WHERE clause."""
        uq = UpdateQuery(User, {User.username: 'updated'}).where(User.id == 2)
        self.assertWhere(uq, 'users."id" = ?', [2])
class InsertTestCase(BasePeeweeTestCase):
    """Compile-only check of an INSERT's column/value mapping."""
    def test_insert(self):
        iq = InsertQuery(User, {User.username: 'inserted'})
        self.assertInsert(iq, [('"username"', '?')], ['inserted'])
class DeleteTestCase(BasePeeweeTestCase):
    """Compile-only check of a DELETE's WHERE clause."""
    def test_where(self):
        dq = DeleteQuery(User).where(User.id == 2)
        self.assertWhere(dq, 'users."id" = ?', [2])
class RawTestCase(BasePeeweeTestCase):
    """RawQuery passes its SQL through unchanged, with params as a list."""
    def test_raw(self):
        q = 'SELECT * FROM "users" WHERE id=?'
        rq = RawQuery(User, q, 100)
        self.assertEqual(rq.sql(compiler), (q, [100]))
class SugarTestCase(BasePeeweeTestCase):
    """Compile-only checks of the convenience APIs built on top of SelectQuery."""
    # test things like filter, annotate, aggregate
    def test_filter(self):
        """filter() kwargs with ``__`` lookups add the joins they traverse."""
        sq = User.filter(username='u1')
        self.assertJoins(sq, [])
        self.assertWhere(sq, 'users."username" = ?', ['u1'])
        sq = Blog.filter(user__username='u1')
        self.assertJoins(sq, ['INNER JOIN "users" AS users ON blog."user_id" = users."id"'])
        self.assertWhere(sq, 'users."username" = ?', ['u1'])
        sq = Blog.filter(user__username__in=['u1', 'u2'], comments__comment='hurp')
        self.assertJoins(sq, [
            'INNER JOIN "comment" AS comment ON blog."pk" = comment."blog_id"',
            'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
        ])
        self.assertWhere(sq, 'comment."comment" = ? AND users."username" IN (?,?)', ['hurp', 'u1', 'u2'])
        # Chained filter() calls: conditions appear in call order.
        sq = Blog.filter(user__username__in=['u1', 'u2']).filter(comments__comment='hurp')
        self.assertJoins(sq, [
            'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
            'INNER JOIN "comment" AS comment ON blog."pk" = comment."blog_id"',
        ])
        self.assertWhere(sq, 'users."username" IN (?,?) AND comment."comment" = ?', ['u1', 'u2', 'hurp'])
    def test_filter_dq(self):
        """DQ objects combine with ``|``; separate positional DQs are ANDed."""
        sq = User.filter(DQ(username='u1') | DQ(username='u2'))
        self.assertJoins(sq, [])
        self.assertWhere(sq, '(users."username" = ? OR users."username" = ?)', ['u1', 'u2'])
        sq = Comment.filter(DQ(blog__user__username='u1') | DQ(blog__title='b1'), DQ(comment='c1'))
        self.assertJoins(sq, [
            'INNER JOIN "blog" AS blog ON comment."blog_id" = blog."pk"',
            'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
        ])
        self.assertWhere(sq, '(users."username" = ? OR blog."title" = ?) AND comment."comment" = ?', ['u1', 'b1', 'c1'])
        sq = Blog.filter(DQ(user__username='u1') | DQ(comments__comment='c1'))
        self.assertJoins(sq, [
            'INNER JOIN "comment" AS comment ON blog."pk" = comment."blog_id"',
            'INNER JOIN "users" AS users ON blog."user_id" = users."id"',
        ])
        self.assertWhere(sq, '(users."username" = ? OR comment."comment" = ?)', ['u1', 'c1'])
    def test_annotate(self):
        """annotate() adds an aggregate column, a join and a GROUP BY of the selection.

        It defaults to Count(pk) AS count. Annotating the same model twice
        reuses one join.
        """
        sq = User.select().annotate(Blog)
        self.assertSelect(sq, 'users."id", users."username", Count(blog."pk") AS count', [])
        self.assertJoins(sq, ['INNER JOIN "blog" AS blog ON users."id" = blog."user_id"'])
        self.assertWhere(sq, '', [])
        self.assertGroupBy(sq, 'users."id", users."username"', [])
        sq = User.select(User.username).annotate(Blog, fn.Sum(Blog.pk).alias('sum')).where(User.username == 'foo')
        self.assertSelect(sq, 'users."username", Sum(blog."pk") AS sum', [])
        self.assertJoins(sq, ['INNER JOIN "blog" AS blog ON users."id" = blog."user_id"'])
        self.assertWhere(sq, 'users."username" = ?', ['foo'])
        self.assertGroupBy(sq, 'users."username"', [])
        sq = User.select(User.username).annotate(Blog).annotate(Blog, fn.Max(Blog.pk).alias('mx'))
        self.assertSelect(sq, 'users."username", Count(blog."pk") AS count, Max(blog."pk") AS mx', [])
        self.assertJoins(sq, ['INNER JOIN "blog" AS blog ON users."id" = blog."user_id"'])
        self.assertWhere(sq, '', [])
        self.assertGroupBy(sq, 'users."username"', [])
        sq = User.select().annotate(Blog).order_by(R('count DESC'))
        self.assertSelect(sq, 'users."id", users."username", Count(blog."pk") AS count', [])
        self.assertOrderBy(sq, 'count DESC', [])
        # An existing LEFT OUTER join is kept rather than replaced by an inner join.
        sq = User.select().join(Blog, JOIN_LEFT_OUTER).switch(User).annotate(Blog)
        self.assertSelect(sq, 'users."id", users."username", Count(blog."pk") AS count', [])
        self.assertJoins(sq, ['LEFT OUTER JOIN "blog" AS blog ON users."id" = blog."user_id"'])
        self.assertWhere(sq, '', [])
        self.assertGroupBy(sq, 'users."id", users."username"', [])
    def test_aggregate(self):
        """_aggregate() selects Count(id) while keeping the WHERE clause."""
        sq = User.select().where(User.id < 10)._aggregate()
        self.assertSelect(sq, 'Count(users."id")', [])
        self.assertWhere(sq, 'users."id" < ?', [10])
#
# TEST CASE USED TO PROVIDE ACCESS TO DATABASE
# FOR EXECUTION OF "LIVE" QUERIES
#
class ModelTestCase(BasePeeweeTestCase):
    """Base for tests that execute queries against ``test_db``.

    ``requires`` lists the models whose tables a test needs; ``None`` means
    every model in MODELS. Those tables are dropped and recreated before
    each test and dropped again afterwards.
    """
    requires = None

    def setUp(self):
        super(ModelTestCase, self).setUp()
        # Drop first so every test starts from freshly created, empty tables.
        drop_tables(self.requires)
        create_tables(self.requires)

    def tearDown(self):
        try:
            drop_tables(self.requires)
        finally:
            # The base tearDown detaches this test's QueryLogHandler. Without
            # this call, every test left another handler on peewee's shared
            # logger, and those handlers kept collecting records indefinitely.
            super(ModelTestCase, self).tearDown()

    def create_user(self, username):
        """Create, save and return a User with the given username."""
        return User.create(username=username)

    def create_users(self, n):
        """Create users named u1 .. un, in that order."""
        for i in range(n):
            self.create_user('u%d' % (i + 1))
class QueryResultWrapperTestCase(ModelTestCase):
    """Behaviour of the result wrapper produced by executing a SELECT.

    Queries whose row order is asserted use an explicit ORDER BY. Without
    one, SQL leaves the order undefined, and non-sqlite backends may return
    the rows in any order.
    """
    requires = [User, Blog, Comment]

    def test_iteration(self):
        """Results are cached: partial then repeated iteration runs a single query."""
        self.create_users(10)
        query_start = len(self.queries())
        sq = User.select().order_by(User.id)
        qr = sq.execute()
        first_five = []
        for i, u in enumerate(qr):
            first_five.append(u.username)
            if i == 4:
                break
        self.assertEqual(first_five, ['u1', 'u2', 'u3', 'u4', 'u5'])
        another_iter = [u.username for u in qr]
        self.assertEqual(another_iter, ['u%d' % i for i in range(1, 11)])
        another_iter = [u.username for u in qr]
        self.assertEqual(another_iter, ['u%d' % i for i in range(1, 11)])
        # only 1 query for these iterations
        self.assertEqual(len(self.queries()) - query_start, 1)

    def test_iterator(self):
        """iterator() streams rows without filling the cache, leaving the wrapper exhausted."""
        self.create_users(10)
        qc = len(self.queries())
        qr = User.select().order_by(User.id).execute()
        usernames = [u.username for u in qr.iterator()]
        self.assertEqual(usernames, ['u%d' % i for i in range(1, 11)])
        qc1 = len(self.queries())
        self.assertEqual(qc1 - qc, 1)
        self.assertTrue(qr._populated)
        self.assertEqual(qr._result_cache, [])
        # Iterating again yields nothing and issues no new query.
        again = [u.username for u in qr]
        self.assertEqual(again, [])
        qc2 = len(self.queries())
        self.assertEqual(qc2 - qc1, 0)
        qr = User.select().where(User.username == 'xxx').execute()
        usernames = [u.username for u in qr.iterator()]
        self.assertEqual(usernames, [])

    def test_select_related(self):
        """Selecting columns of joined models populates related instances without extra queries."""
        u1 = User.create(username='u1')
        u2 = User.create(username='u2')
        b1 = Blog.create(user=u1, title='b1')
        b2 = Blog.create(user=u2, title='b2')
        Comment.create(blog=b1, comment='c11')
        Comment.create(blog=b1, comment='c12')
        Comment.create(blog=b2, comment='c21')
        Comment.create(blog=b2, comment='c22')
        # missing comment.blog_id
        qc = len(self.queries())
        comments = Comment.select(Comment.id, Comment.comment, Blog.pk, Blog.title).join(Blog).where(Blog.title == 'b1').order_by(Comment.id)
        self.assertEqual([c.blog.title for c in comments], ['b1', 'b1'])
        self.assertEqual(len(self.queries()) - qc, 1)
        # missing blog.pk
        qc = len(self.queries())
        comments = Comment.select(Comment.id, Comment.comment, Comment.blog, Blog.title).join(Blog).where(Blog.title == 'b2').order_by(Comment.id)
        self.assertEqual([c.blog.title for c in comments], ['b2', 'b2'])
        self.assertEqual(len(self.queries()) - qc, 1)
        # both but going up 2 levels
        qc = len(self.queries())
        comments = Comment.select(Comment, Blog, User).join(Blog).join(User).where(User.username == 'u1').order_by(Comment.id)
        self.assertEqual([c.comment for c in comments], ['c11', 'c12'])
        self.assertEqual([c.blog.title for c in comments], ['b1', 'b1'])
        self.assertEqual([c.blog.user.username for c in comments], ['u1', 'u1'])
        self.assertEqual(len(self.queries()) - qc, 1)
        # Only Comment columns selected: presumably one query for the comments
        # plus one per c.blog and one per c.blog.user access (1 + 2 + 2).
        qc = len(self.queries())
        comments = Comment.select().join(Blog).join(User).where(User.username == 'u1').order_by(Comment.id)
        self.assertEqual([c.blog.user.username for c in comments], ['u1', 'u1'])
        self.assertEqual(len(self.queries()) - qc, 5)

    def test_naive(self):
        """naive() sets columns from joined models directly on the returned instances."""
        u1 = User.create(username='u1')
        u2 = User.create(username='u2')
        Blog.create(user=u1, title='b1')
        Blog.create(user=u2, title='b2')
        users = User.select().order_by(User.id).naive()
        self.assertEqual([u.username for u in users], ['u1', 'u2'])
        users = User.select(User, Blog).join(Blog).order_by(User.id).naive()
        self.assertEqual([u.username for u in users], ['u1', 'u2'])
        self.assertEqual([u.title for u in users], ['b1', 'b2'])
class ModelQueryTestCase(ModelTestCase):
    """Executing select/update/insert/delete/raw queries against real tables.

    Queries whose row order is asserted use an explicit ORDER BY; without
    one, SQL leaves the row order undefined.
    """
    requires = [User, Blog]

    def create_users_blogs(self, n=10, nb=5):
        """Create users u0 .. u(n-1), each with ``nb`` blogs titled 'b-<user>-<blog>'."""
        for i in range(n):
            u = User.create(username='u%d' % i)
            for j in range(nb):
                Blog.create(title='b-%d-%d' % (i, j), content=str(j), user=u)

    def test_select(self):
        self.create_users_blogs()
        users = User.select().where(User.username << ['u0', 'u5']).order_by(User.username)
        self.assertEqual([u.username for u in users], ['u0', 'u5'])
        blogs = Blog.select().join(User).where(
            (User.username << ['u0', 'u3']) &
            (Blog.content == '4')
        ).order_by(Blog.title)
        self.assertEqual([b.title for b in blogs], ['b-0-4', 'b-3-4'])
        # Page 2 of size 3 -> rows 4..6 in id order.
        users = User.select().order_by(User.id).paginate(2, 3)
        self.assertEqual([u.username for u in users], ['u3', 'u4', 'u5'])

    def test_update(self):
        """An UpdateQuery changes nothing until execute() is called."""
        self.create_users(5)
        uq = User.update(username='u-edited').where(User.username << ['u1', 'u2', 'u3'])
        self.assertEqual([u.username for u in User.select().order_by(User.id)], ['u1', 'u2', 'u3', 'u4', 'u5'])
        uq.execute()
        self.assertEqual([u.username for u in User.select().order_by(User.id)], ['u-edited', 'u-edited', 'u-edited', 'u4', 'u5'])

    def test_insert(self):
        """execute() on an InsertQuery returns the new row's id."""
        iq = User.insert(username='u1')
        self.assertEqual(User.select().count(), 0)
        uid = iq.execute()
        self.assertTrue(uid > 0)
        self.assertEqual(User.select().count(), 1)
        u = User.get(id=uid)
        self.assertEqual(u.username, 'u1')

    def test_delete(self):
        """execute() on a DeleteQuery returns the number of deleted rows."""
        self.create_users(5)
        dq = User.delete().where(User.username << ['u1', 'u2', 'u3'])
        self.assertEqual(User.select().count(), 5)
        nr = dq.execute()
        self.assertEqual(nr, 3)
        self.assertEqual([u.username for u in User.select().order_by(User.id)], ['u4', 'u5'])

    def test_raw(self):
        """Raw queries are cached like SELECTs and may return extra columns."""
        self.create_users(3)
        qc = len(self.queries())
        # INT is the backend's placeholder string.
        rq = User.raw('select * from users where username IN (%s,%s) order by id' % (INT,INT), 'u1', 'u3')
        self.assertEqual([u.username for u in rq], ['u1', 'u3'])
        # iterate again
        self.assertEqual([u.username for u in rq], ['u1', 'u3'])
        self.assertEqual(len(self.queries()) - qc, 1)
        rq = User.raw('select id, username, %s as secret from users where username = %s' % (INT,INT), 'sh', 'u2')
        self.assertEqual([u.secret for u in rq], ['sh'])
        self.assertEqual([u.username for u in rq], ['u2'])
class ModelAPITestCase(ModelTestCase):
    """Model-instance API: related accessors, foreign keys, save/get/delete, counting."""
    requires = [User, Blog, Category, UserCategory]

    def test_related_name(self):
        """The default reverse accessor ``<model>_set`` lists related rows."""
        u1 = self.create_user('u1')
        u2 = self.create_user('u2')
        Blog.create(user=u1, title='b11')
        Blog.create(user=u1, title='b12')
        Blog.create(user=u2, title='b2')
        # Explicit ORDER BY: SQL does not otherwise guarantee row order.
        self.assertEqual([b.title for b in u1.blog_set.order_by(Blog.pk)], ['b11', 'b12'])
        self.assertEqual([b.title for b in u2.blog_set], ['b2'])

    def test_fk_exceptions(self):
        """A null nullable FK reads as None; an unset required FK raises DoesNotExist."""
        c1 = Category.create(name='c1')
        c2 = Category.create(parent=c1, name='c2')
        self.assertEqual(c1.parent, None)
        self.assertEqual(c2.parent, c1)
        c2_db = Category.get(Category.id == c2.id)
        self.assertEqual(c2_db.parent, c1)
        u = self.create_user('u1')
        b = Blog.create(user=u, title='b')
        b2 = Blog(title='b2')
        self.assertEqual(b.user, u)
        self.assertRaises(User.DoesNotExist, getattr, b2, 'user')

    def test_fk_ints(self):
        """A foreign key may be assigned a raw id instead of an instance."""
        c1 = Category.create(name='c1')
        c2 = Category.create(name='c2', parent=c1.id)
        c2_db = Category.get(Category.id == c2.id)
        self.assertEqual(c2_db.parent, c1)

    def test_fk_caching(self):
        """The related object is fetched once, then cached on the instance."""
        c1 = Category.create(name='c1')
        c2 = Category.create(name='c2', parent=c1)
        c2_db = Category.get(Category.id == c2.id)
        qc = len(self.queries())
        parent = c2_db.parent
        self.assertEqual(parent, c1)
        parent = c2_db.parent
        self.assertEqual(len(self.queries()) - qc, 1)

    def test_creation(self):
        self.create_users(10)
        self.assertEqual(User.select().count(), 10)

    def test_saving(self):
        """Saving an instance twice does not create a second row."""
        self.assertEqual(User.select().count(), 0)
        u = User(username='u1')
        u.save()
        u.save()
        self.assertEqual(User.select().count(), 1)

    def test_saving_via_create_gh111(self):
        """Unset nullable fields are left out of the INSERT (GitHub issue 111)."""
        u = User.create(username='u')
        b = Blog.create(title='foo', user=u)
        last_sql, _ = self.queries()[-1]
        self.assertFalse('pub_date' in last_sql)
        self.assertEqual(b.pub_date, None)
        b2 = Blog(title='foo2', user=u)
        b2.save()
        last_sql, _ = self.queries()[-1]
        self.assertFalse('pub_date' in last_sql)
        self.assertEqual(b2.pub_date, None)

    def test_reading(self):
        """get() accepts keyword lookups or expressions."""
        u1 = self.create_user('u1')
        u2 = self.create_user('u2')
        self.assertEqual(u1, User.get(username='u1'))
        self.assertEqual(u2, User.get(username='u2'))
        self.assertFalse(u1 == u2)
        self.assertEqual(u1, User.get(User.username == 'u1'))
        self.assertEqual(u2, User.get(User.username == 'u2'))

    def test_get_or_create(self):
        """A second get_or_create with the same lookup returns the existing row."""
        u1 = User.get_or_create(username='u1')
        u1_x = User.get_or_create(username='u1')
        self.assertEqual(u1.id, u1_x.id)
        self.assertEqual(User.select().count(), 1)

    def test_deleting(self):
        u1 = self.create_user('u1')
        u2 = self.create_user('u2')
        self.assertEqual(User.select().count(), 2)
        u1.delete_instance()
        self.assertEqual(User.select().count(), 1)
        self.assertEqual(u2, User.get(username='u2'))

    def test_counting(self):
        """count() over a join counts joined rows; distinct() collapses duplicates."""
        u1 = User.create(username='u1')
        u2 = User.create(username='u2')
        for u in [u1, u2]:
            for i in range(5):
                Blog.create(title='b-%s-%s' % (u.username, i), user=u)
        uc = User.select().where(User.username == 'u1').join(Blog).count()
        self.assertEqual(uc, 5)
        uc = User.select().where(User.username == 'u1').join(Blog).distinct().count()
        self.assertEqual(uc, 1)

    def test_count_transaction(self):
        """Rows created inside a transaction block are visible to a count afterwards."""
        for i in range(10):
            self.create_user(username='u%d' % i)
        with transaction(test_db):
            for user in SelectQuery(User):
                for i in range(20):
                    Blog.create(user=user, title='b-%d-%d' % (user.id, i))
        count = SelectQuery(Blog).count()
        self.assertEqual(count, 200)

    def test_exists(self):
        User.create(username='u1')
        self.assertTrue(User.select().where(User.username == 'u1').exists())
        self.assertFalse(User.select().where(User.username == 'u2').exists())

    def test_unicode(self):
        """Non-ASCII text round-trips through the database unchanged."""
        ustr = u'Lýðveldið Ísland'
        self.create_user(username=ustr)
        u2 = User.get(User.username == ustr)
        self.assertEqual(u2.username, ustr)
class RecursiveDeleteTestCase(BasePeeweeTestCase):
    """delete_instance(recursive=...) on a Parent with two levels of dependents.

    Fixture: two parents. Each has two Child rows (non-null FK) and two
    Orphan rows (nullable FK). Every Child / Orphan has one pet.
    """
    def setUp(self):
        super(RecursiveDeleteTestCase, self).setUp()
        # Referenced tables first; True is passed to each create_table call.
        for model in (Parent, Child, Orphan, ChildPet, OrphanPet):
            model.create_table(True)
        p1 = Parent.create(data='p1')
        p2 = Parent.create(data='p2')
        owners = (p1, p1, p2, p2)
        children = [Child.create(parent=owner) for owner in owners]
        orphans = [Orphan.create(parent=owner) for owner in owners]
        for child in children:
            ChildPet.create(child=child)
        for orphan in orphans:
            OrphanPet.create(orphan=orphan)
        self.p1, self.p2 = p1, p2

    def tearDown(self):
        super(RecursiveDeleteTestCase, self).tearDown()
        # Dependents first, the reverse of creation order.
        for model in (OrphanPet, ChildPet, Orphan, Child, Parent):
            model.drop_table()

    def _check_counts(self, expectations):
        # Each entry: (query, fk field, rows for p1, rows for p2, total rows).
        for query, fk, on_p1, on_p2, total in expectations:
            self.assertEqual(query.where(fk == self.p1).count(), on_p1)
            self.assertEqual(query.where(fk == self.p2).count(), on_p2)
            self.assertEqual(query.count(), total)

    def test_recursive_update(self):
        """Without delete_nullable, p1's orphans survive (totals stay 4) but are no longer linked to p1."""
        self.p1.delete_instance(recursive=True)
        self._check_counts((
            (Child.select(), Child.parent, 0, 2, 2),
            (Orphan.select(), Orphan.parent, 0, 2, 4),
            (ChildPet.select().join(Child), Child.parent, 0, 2, 2),
            (OrphanPet.select().join(Orphan), Orphan.parent, 0, 2, 4),
        ))

    def test_recursive_delete(self):
        """With delete_nullable=True, every dependent of p1 is deleted."""
        self.p1.delete_instance(recursive=True, delete_nullable=True)
        self._check_counts((
            (Child.select(), Child.parent, 0, 2, 2),
            (Orphan.select(), Orphan.parent, 0, 2, 2),
            (ChildPet.select().join(Child), Child.parent, 0, 2, 2),
            (OrphanPet.select().join(Orphan), Orphan.parent, 0, 2, 2),
        ))
class MultipleFKTestCase(ModelTestCase):
    """A model with two foreign keys to the same target (User)."""
    requires = [User, Relationship]

    def test_multiple_fks(self):
        a = User.create(username='a')
        b = User.create(username='b')
        c = User.create(username='c')

        self.assertEqual(list(a.relationships), [])
        self.assertEqual(list(a.related_to), [])

        r_ab = Relationship.create(from_user=a, to_user=b)
        self.assertEqual(list(a.relationships), [r_ab])
        self.assertEqual(list(a.related_to), [])
        self.assertEqual(list(b.relationships), [])
        self.assertEqual(list(b.related_to), [r_ab])

        Relationship.create(from_user=b, to_user=c)

        def following(ref):
            # users reached through to_user from relationships owned by ref
            return list(User.select().join(
                Relationship, on=Relationship.to_user
            ).where(Relationship.from_user == ref))

        def followers(ref):
            # users reached through from_user on relationships pointing at ref
            return list(User.select().join(
                Relationship, on=Relationship.from_user
            ).where(Relationship.to_user == ref))

        self.assertEqual(following(a), [b])
        self.assertEqual(followers(a.id), [])
        self.assertEqual(following(b.id), [c])
        self.assertEqual(followers(b.id), [a])
        self.assertEqual(following(c.id), [])
        self.assertEqual(followers(c.id), [b])
class ManyToManyTestCase(ModelTestCase):
    """Users and Categories linked through the UserCategory join model."""
    requires = [User, Category, UserCategory]

    def test_m2m(self):
        u1, u2, u3 = [User.create(username=n) for n in ('u1', 'u2', 'u3')]
        c1, c2, c3 = [Category.create(name=n) for n in ('c1', 'c2', 'c3')]
        # extras
        c12 = Category.create(name='c12')
        c23 = Category.create(name='c23')

        links = ((u1, c1), (u2, c2), (u1, c12), (u2, c12), (u2, c23))
        for user, category in links:
            UserCategory.create(user=user, category=category)

        def assert_usernames(query, expected):
            ordered = query.order_by(User.username)
            self.assertEqual([u.username for u in ordered], expected)

        def assert_category_names(query, expected):
            ordered = query.order_by(Category.name)
            self.assertEqual([c.name for c in ordered], expected)

        def users_in(category_name):
            return (User.select()
                    .join(UserCategory)
                    .join(Category)
                    .where(Category.name == category_name))

        def categories_of(username):
            return (Category.select()
                    .join(UserCategory)
                    .join(User)
                    .where(User.username == username))

        assert_usernames(users_in('c1'), ['u1'])
        assert_usernames(users_in('c3'), [])
        assert_category_names(categories_of('u1'), ['c1', 'c12'])
        assert_category_names(categories_of('u2'), ['c12', 'c2', 'c23'])
        assert_category_names(categories_of('u3'), [])

        # inner joins drop c3, which has no linked users...
        inner = (Category.select()
                 .join(UserCategory)
                 .join(User)
                 .where(Category.name << ['c1', 'c2', 'c3']))
        assert_category_names(inner, ['c1', 'c2'])

        # ...while left outer joins keep it
        outer = (Category.select()
                 .join(UserCategory, JOIN_LEFT_OUTER)
                 .join(User, JOIN_LEFT_OUTER)
                 .where(Category.name << ['c1', 'c2', 'c3']))
        assert_category_names(outer, ['c1', 'c2', 'c3'])
class FieldTypeTestCase(ModelTestCase):
    """Round-trip and comparison behaviour of the basic field types."""
    requires = [NullModel]

    _dt = datetime.datetime
    _d = datetime.date
    _t = datetime.time
    # first row: column names; remaining rows: one NullModel each
    _data = (
        ('char_field', 'text_field', 'int_field', 'float_field', 'decimal_field1', 'datetime_field', 'date_field', 'time_field'),
        ('c1', 't1', 1, 1.0, "1.0", _dt(2010, 1, 1), _d(2010, 1, 1), _t(1, 0)),
        ('c2', 't2', 2, 2.0, "2.0", _dt(2010, 1, 2), _d(2010, 1, 2), _t(2, 0)),
        ('c3', 't3', 3, 3.0, "3.0", _dt(2010, 1, 3), _d(2010, 1, 3), _t(3, 0)),
    )

    def setUp(self):
        super(FieldTypeTestCase, self).setUp()
        # field_data maps column name -> list of values in row order
        self.field_data = {}
        headers, rows = self._data[0], self._data[1:]
        for row in rows:
            instance = NullModel()
            for attr, value in zip(headers, row):
                self.field_data.setdefault(attr, []).append(value)
                setattr(instance, attr, value)
            instance.save()

    def assertNM(self, q, exp):
        """Assert the char_field values (ordered by id) of rows matching q."""
        matches = NullModel.select().where(q).order_by(NullModel.id)
        self.assertEqual([row.char_field for row in matches], exp)

    def test_field_types(self):
        for field, values in self.field_data.items():
            column = getattr(NullModel, field)
            first, second, third = values
            self.assertNM(column < third, ['c1', 'c2'])
            self.assertNM(column <= second, ['c1', 'c2'])
            self.assertNM(column > first, ['c2', 'c3'])
            self.assertNM(column >= second, ['c2', 'c3'])
            self.assertNM(column == second, ['c2'])
            self.assertNM(column != second, ['c1', 'c3'])
            self.assertNM(column << [first, third], ['c1', 'c3'])
            self.assertNM(column << [second], ['c2'])

    def test_charfield(self):
        saved = NullModel.create(char_field=4)
        self.assertEqual(NullModel.get(id=saved.id).char_field, '4')

    def test_intfield(self):
        saved = NullModel.create(int_field='4')
        self.assertEqual(NullModel.get(id=saved.id).int_field, 4)

    def test_floatfield(self):
        saved = NullModel.create(float_field='4.2')
        self.assertEqual(NullModel.get(id=saved.id).float_field, 4.2)

    def test_decimalfield(self):
        D = decimal.Decimal
        nm = NullModel()
        nm.decimal_field1 = D("3.14159265358979323")
        nm.decimal_field2 = D("100.33")
        nm.save()

        reloaded = NullModel.get(id=nm.id)
        # decimal_field1's stored precision is deliberately not asserted:
        # sqlite doesn't enforce these constraints properly.
        self.assertEqual(reloaded.decimal_field2, D("100.33"))

        class TestDecimalModel(TestModel):
            df1 = DecimalField(decimal_places=2, auto_round=True)
            df2 = DecimalField(decimal_places=2, auto_round=True, rounding=decimal.ROUND_UP)

        round_default = TestDecimalModel.df1.db_value
        round_up = TestDecimalModel.df2.db_value
        self.assertEqual(round_default(D('1.2345')), D('1.23'))
        self.assertEqual(round_up(D('1.2345')), D('1.24'))

    def test_boolfield(self):
        NullModel.delete().execute()
        for value, label in ((True, 't'), (False, 'f'), (None, 'n')):
            NullModel.create(boolean_field=value, char_field=label)
        # `== True` / `>> None` build SQL expressions, so they stay explicit
        self.assertNM(NullModel.boolean_field == True, ['t'])
        self.assertNM(NullModel.boolean_field == False, ['f'])
        self.assertNM(NullModel.boolean_field >> None, ['n'])

    def test_date_and_time_fields(self):
        dt_micro = datetime.datetime(2011, 1, 2, 11, 12, 13, 54321)
        dt_plain = datetime.datetime(2011, 1, 2, 11, 12, 13)
        day = datetime.date(2011, 1, 3)
        t_micro = datetime.time(11, 12, 13, 54321)
        t_plain = datetime.time(11, 12, 13)

        first = NullModel.create(datetime_field=dt_micro, date_field=day, time_field=t_micro)
        second = NullModel.create(datetime_field=dt_plain, time_field=t_plain)

        fetched = NullModel.get(id=first.id)
        self.assertEqual(fetched.date_field, day)
        if BACKEND == 'mysql':
            # mysql doesn't store microseconds
            want_dt, want_t = dt_plain, t_plain
        else:
            want_dt, want_t = dt_micro, t_micro
        self.assertEqual(fetched.datetime_field, want_dt)
        self.assertEqual(fetched.time_field, want_t)

        fetched = NullModel.get(id=second.id)
        self.assertEqual(fetched.datetime_field, dt_plain)
        self.assertEqual(fetched.time_field, t_plain)

    def test_various_formats(self):
        dt = datetime.datetime
        d = datetime.date
        t = datetime.time

        class FormatModel(Model):
            dtf = DateTimeField()
            df = DateField()
            tf = TimeField()

        fields = FormatModel._meta.fields
        dtf, df, tf = fields['dtf'], fields['df'], fields['tf']
        # (field, raw string, expected python_value); unparseable strings
        # are expected back unchanged
        default_cases = (
            (dtf, '2012-01-01 11:11:11.123456', dt(2012, 1, 1, 11, 11, 11, 123456)),
            (dtf, '2012-01-01 11:11:11', dt(2012, 1, 1, 11, 11, 11)),
            (dtf, '2012-01-01', dt(2012, 1, 1)),
            (dtf, '2012 01 01', '2012 01 01'),
            (df, '2012-01-01 11:11:11.123456', d(2012, 1, 1)),
            (df, '2012-01-01 11:11:11', d(2012, 1, 1)),
            (df, '2012-01-01', d(2012, 1, 1)),
            (df, '2012 01 01', '2012 01 01'),
            (tf, '2012-01-01 11:11:11.123456', t(11, 11, 11, 123456)),
            (tf, '2012-01-01 11:11:11', t(11, 11, 11)),
            (tf, '11:11:11.123456', t(11, 11, 11, 123456)),
            (tf, '11:11:11', t(11, 11, 11)),
            (tf, '11:11', t(11, 11)),
            (tf, '11:11 AM', '11:11 AM'),
        )
        for field, raw, expected in default_cases:
            self.assertEqual(field.python_value(raw), expected)

        class CustomFormatsModel(Model):
            dtf = DateTimeField(formats=['%b %d, %Y %I:%M:%S %p'])
            df = DateField(formats=['%b %d, %Y'])
            tf = TimeField(formats=['%I:%M %p'])

        fields = CustomFormatsModel._meta.fields
        dtf, df, tf = fields['dtf'], fields['df'], fields['tf']
        # with explicit formats, the default ISO-style strings no longer parse
        custom_cases = (
            (dtf, '2012-01-01 11:11:11.123456', '2012-01-01 11:11:11.123456'),
            (dtf, 'Jan 1, 2012 11:11:11 PM', dt(2012, 1, 1, 23, 11, 11)),
            (df, '2012-01-01', '2012-01-01'),
            (df, 'Jan 1, 2012', d(2012, 1, 1)),
            (tf, '11:11:11', '11:11:11'),
            (tf, '11:11 PM', t(23, 11)),
        )
        for field, raw, expected in custom_cases:
            self.assertEqual(field.python_value(raw), expected)
class UniqueTestCase(ModelTestCase):
    """Single-column and multi-column uniqueness constraints."""
    requires = [UniqueModel, MultiIndexModel]

    def test_unique(self):
        for name in ('a', 'b'):
            UniqueModel.create(name=name)
        # a repeated name is rejected; roll back the failed statement
        self.assertRaises(Exception, UniqueModel.create, name='a')
        test_db.rollback()

    def test_multi_index(self):
        MultiIndexModel.create(f1='a', f2='a', f3='a')
        MultiIndexModel.create(f1='b', f2='b', f3='b')
        # The expectations below imply a unique index spanning f1 and f2:
        # repeating an (f1, f2) pair fails whatever f3 is.
        rejected = (
            dict(f1='a', f2='a', f3='b'),
            dict(f1='b', f2='b', f3='a'),
        )
        for values in rejected:
            self.assertRaises(Exception, MultiIndexModel.create, **values)
            test_db.rollback()
        # a new (f1, f2) combination is accepted
        MultiIndexModel.create(f1='a', f2='b', f3='b')
class NonIntPKTestCase(ModelTestCase):
    """Models keyed by a string primary key, and FKs pointing at them."""
    requires = [NonIntModel, NonIntRelModel]

    def test_non_int_pk(self):
        first = NonIntModel.create(pk='a1', data='ni1')
        self.assertEqual(first.pk, 'a1')

        second = NonIntModel(pk='a2', data='ni2')
        # the pk is already populated, so request an INSERT explicitly
        second.save(force_insert=True)
        self.assertEqual(second.pk, 'a2')
        second.save()
        self.assertEqual(second.pk, 'a2')

        self.assertEqual(NonIntModel.select().count(), 2)
        self.assertEqual(NonIntModel.get(pk='a1').data, first.data)

        ordered = NonIntModel.select().order_by(NonIntModel.pk)
        self.assertEqual([(row.pk, row.data) for row in ordered], [
            ('a1', 'ni1'), ('a2', 'ni2'),
        ])

    def test_non_int_fk(self):
        ni1 = NonIntModel.create(pk='a1', data='ni1')
        ni2 = NonIntModel.create(pk='a2', data='ni2')

        def related_ids(parent):
            # ids of NonIntRelModel rows reachable through the `nr` backref
            return [rel.id for rel in parent.nr.order_by(NonIntRelModel.id)]

        rni11 = NonIntRelModel(non_int_model=ni1)
        rni12 = NonIntRelModel(non_int_model=ni1)
        rni11.save()
        rni12.save()
        self.assertEqual(related_ids(ni1), [rni11.id, rni12.id])
        self.assertEqual(related_ids(ni2), [])

        rni21 = NonIntRelModel.create(non_int_model=ni2)
        self.assertEqual(related_ids(ni2), [rni21.id])

        joined = NonIntRelModel.select().join(NonIntModel).where(NonIntModel.data == 'ni2')
        self.assertEqual([rel.id for rel in joined], [rni21.id])
class DBColumnTestCase(ModelTestCase):
    """Fields whose database column name differs from the attribute name."""
    requires = [DBUser, DBBlog]

    def test_select(self):
        # generated SQL must use the db_column names, not attribute names
        query = DBUser.select().where(DBUser.username == 'u1')
        self.assertSelect(query, 'dbuser."db_user_id", dbuser."db_username"', [])
        self.assertWhere(query, 'dbuser."db_username" = ?', ['u1'])

        query = (DBUser.select(DBUser.user_id)
                 .join(DBBlog)
                 .where(DBBlog.title == 'b1'))
        self.assertSelect(query, 'dbuser."db_user_id"', [])
        self.assertJoins(query, ['INNER JOIN "dbblog" AS dbblog ON dbuser."db_user_id" = dbblog."db_user"'])
        self.assertWhere(query, 'dbblog."db_title" = ?', ['b1'])

    def test_db_column(self):
        u1, u2 = [DBUser.create(username=name) for name in ('u1', 'u2')]
        self.assertEqual(DBUser.get(user_id=u2.get_id()).username, 'u2')

        DBBlog.create(user=u1, title='b1')
        b2 = DBBlog.create(user=u2, title='b2')

        reloaded = DBBlog.get(blog_id=b2.get_id())
        self.assertEqual(reloaded.user.user_id, u2.user_id)
        self.assertEqual(reloaded.title, 'b2')
        self.assertEqual([blog.title for blog in u2.dbblog_set], ['b2'])
class TransactionTestCase(ModelTestCase):
    """Autocommit toggling, commit_on_success and transaction context managers."""
    requires = [User, Blog]

    def tearDown(self):
        # Restore autocommit before the base tearDown runs, so any queries it
        # issues are not left inside a transaction a failed test kept open.
        test_db.set_autocommit(True)
        super(TransactionTestCase, self).tearDown()

    def test_autocommit(self):
        test_db.set_autocommit(False)
        User.create(username='u1')
        User.create(username='u2')

        # A second, independent connection cannot see the uncommitted users.
        new_db = database_class(database_name)
        new_db.connect()
        try:
            res = new_db.execute_sql('select count(*) from users;')
            self.assertEqual(res.fetchone()[0], 0)

            # commit our user inserts
            test_db.commit()

            # now the users are query-able from the other connection
            res = new_db.execute_sql('select count(*) from users;')
            self.assertEqual(res.fetchone()[0], 2)
        finally:
            # always release the extra connection, even if an assert fails
            new_db.close()

    def test_commit_on_success(self):
        self.assertTrue(test_db.get_autocommit())

        @test_db.commit_on_success
        def will_fail():
            u = User.create(username='u1')
            b = Blog.create()  # no title/user supplied: expected to raise
            return u, b

        self.assertRaises(Exception, will_fail)
        # the failure discarded the user insert as well
        self.assertEqual(User.select().count(), 0)
        self.assertEqual(Blog.select().count(), 0)

        @test_db.commit_on_success
        def will_succeed():
            u = User.create(username='u1')
            b = Blog.create(title='b1', user=u)
            return u, b

        will_succeed()
        self.assertEqual(User.select().count(), 1)
        self.assertEqual(Blog.select().count(), 1)

    def test_context_mgr(self):
        def will_fail():
            u = User.create(username='u1')
            b = Blog.create()  # no title/user supplied: expected to raise
            return u, b

        def do_will_fail():
            with transaction(test_db):
                will_fail()

        def do_will_fail2():
            with test_db.transaction():
                will_fail()

        # both context-manager spellings roll back on error
        self.assertRaises(Exception, do_will_fail)
        self.assertEqual(Blog.select().count(), 0)
        self.assertRaises(Exception, do_will_fail2)
        self.assertEqual(Blog.select().count(), 0)

        def will_succeed():
            u = User.create(username='u1')
            b = Blog.create(title='b1', user=u)
            return u, b

        def do_will_succeed():
            with transaction(test_db):
                will_succeed()

        def do_will_succeed2():
            with test_db.transaction():
                will_succeed()

        do_will_succeed()
        self.assertEqual(User.select().count(), 1)
        self.assertEqual(Blog.select().count(), 1)
        do_will_succeed2()
        self.assertEqual(User.select().count(), 2)
        self.assertEqual(Blog.select().count(), 2)
class ConcurrencyTestCase(ModelTestCase):
    """Concurrent reads and writes through a thread-local database."""
    requires = [User]

    def setUp(self):
        self._orig_db = test_db
        # threadlocals=True: each thread gets its own connection
        User._meta.database = database_class(database_name, threadlocals=True)
        super(ConcurrencyTestCase, self).setUp()

    def tearDown(self):
        User._meta.database = self._orig_db
        super(ConcurrencyTestCase, self).tearDown()

    def test_multiple_writers(self):
        def create_user_thread(low, hi):
            try:
                for i in range(low, hi):
                    User.create(username='u%d' % i)
            finally:
                # release this thread's connection even if an insert fails
                User._meta.database.close()

        threads = [
            threading.Thread(target=create_user_thread, args=(i * 10, i * 10 + 10))
            for i in range(5)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        self.assertEqual(User.select().count(), 50)

    def test_multiple_readers(self):
        data_queue = Queue.Queue()

        def reader_thread(q, num):
            try:
                for _ in range(num):
                    q.put(User.select().count())
            finally:
                # readers open a per-thread connection too; close it
                User._meta.database.close()

        threads = [
            threading.Thread(target=reader_thread, args=(data_queue, 20))
            for _ in range(5)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        self.assertEqual(data_queue.qsize(), 100)
        # the table is empty, so every reader must have seen a count of 0
        counts = [data_queue.get_nowait() for _ in range(100)]
        self.assertEqual(set(counts), set([0]))
class ModelOptionInheritanceTestCase(BasePeeweeTestCase):
    """Meta options (db_table, database, fields) as seen by subclasses."""

    def test_db_table(self):
        self.assertEqual(User._meta.db_table, 'users')

        # the default table name is the lower-cased class name
        class Foo(TestModel):
            pass
        self.assertEqual(Foo._meta.db_table, 'foo')

        class Foo2(TestModel):
            pass
        self.assertEqual(Foo2._meta.db_table, 'foo2')

        class Foo_3(TestModel):
            pass
        self.assertEqual(Foo_3._meta.db_table, 'foo_3')

    def test_option_inheritance(self):
        x_test_db = SqliteDatabase('testing.db')
        child2_db = SqliteDatabase('child2.db')

        class FakeUser(Model):
            pass

        class ParentModel(Model):
            title = CharField()
            user = ForeignKeyField(FakeUser)

            class Meta:
                database = x_test_db

        class ChildModel(ParentModel):
            pass

        class ChildModel2(ParentModel):
            special_field = CharField()

            class Meta:
                database = child2_db

        class GrandChildModel(ChildModel):
            pass

        class GrandChildModel2(ChildModel2):
            special_field = TextField()

        self.assertEqual(ParentModel._meta.database.database, 'testing.db')
        self.assertEqual(ParentModel._meta.model_class, ParentModel)

        plain = ['id', 'title', 'user']
        special = ['id', 'special_field', 'title', 'user']
        # (model, expected database name, expected sorted field names)
        expectations = (
            (ChildModel, 'testing.db', plain),
            (ChildModel2, 'child2.db', special),
            (GrandChildModel, 'testing.db', plain),
            (GrandChildModel2, 'child2.db', special),
        )
        for model, db_name, field_names in expectations:
            self.assertEqual(model._meta.database.database, db_name)
            self.assertEqual(model._meta.model_class, model)
            self.assertEqual(sorted(model._meta.fields.keys()), field_names)

        # a redeclared field in the subclass replaces the inherited one
        self.assertTrue(isinstance(GrandChildModel2._meta.fields['special_field'], TextField))
class ModelInheritanceTestCase(ModelTestCase):
    """Blog and BlogTwo: field layout, primary keys and backrefs on User."""
    requires = [Blog, BlogTwo, User]

    def test_model_inheritance_attrs(self):
        blog_meta = Blog._meta
        blog2_meta = BlogTwo._meta
        self.assertEqual(blog_meta.get_field_names(), ['pk', 'user', 'title', 'content', 'pub_date'])
        self.assertEqual(blog2_meta.get_field_names(), ['id', 'user', 'content', 'pub_date', 'title', 'extra_field'])
        self.assertEqual(blog_meta.primary_key.name, 'pk')
        self.assertEqual(blog2_meta.primary_key.name, 'id')

        # each model registers its own backref on User
        self.assertEqual(Blog.user.related_name, 'blog_set')
        self.assertEqual(BlogTwo.user.related_name, 'blogtwo_set')
        self.assertEqual(User.blog_set.rel_model, Blog)
        self.assertEqual(User.blogtwo_set.rel_model, BlogTwo)

        self.assertNotEqual(blog2_meta.db_table, blog_meta.db_table)

    def test_model_inheritance_flow(self):
        u = User.create(username='u')
        b = Blog.create(title='b', user=u)
        b2 = BlogTwo.create(title='b2', extra_field='foo', user=u)

        self.assertEqual(list(u.blog_set), [b])
        self.assertEqual(list(u.blogtwo_set), [b2])
        for model in (Blog, BlogTwo):
            self.assertEqual(model.select().count(), 1)

        b_from_db = Blog.get(pk=b.pk)
        b2_from_db = BlogTwo.get(id=b2.id)
        self.assertEqual(b_from_db.user, u)
        self.assertEqual(b2_from_db.user, u)
        self.assertEqual(b2_from_db.extra_field, 'foo')
class DatabaseTestCase(BasePeeweeTestCase):
    """Behaviour of a database created without a name (deferred)."""

    def test_deferred_database(self):
        deferred_db = SqliteDatabase(None)
        self.assertTrue(deferred_db.deferred)

        class DeferredModel(Model):
            class Meta:
                database = deferred_db

        # neither connecting nor running a query works while deferred
        self.assertRaises(Exception, deferred_db.connect)
        sq = DeferredModel.select()
        self.assertRaises(Exception, sq.execute)

        deferred_db.init(':memory:')
        self.assertFalse(deferred_db.deferred)

        # connecting works
        deferred_db.connect()
        try:
            DeferredModel.create_table()
            sq = DeferredModel.select()
            self.assertEqual(list(sq), [])
        finally:
            # release the in-memory connection before deferring again
            deferred_db.close()

        deferred_db.init(None)
        self.assertTrue(deferred_db.deferred)
class ConnectionStateTestCase(BasePeeweeTestCase):
    """is_closed() tracks get_conn()/close() on the shared test database."""

    def test_connection_state(self):
        def open_and_check():
            # get_conn() hands back a live connection, reopening if needed
            test_db.get_conn()
            self.assertFalse(test_db.is_closed())

        open_and_check()
        test_db.close()
        self.assertTrue(test_db.is_closed())
        open_and_check()
class TopologicalSortTestCase(unittest.TestCase):
    """Properties of sort_models_topologically on a small FK graph."""

    def test_topological_sort_fundamentals(self):
        FKF = ForeignKeyField
        # we will be topo-sorting the following models
        class A(Model): pass
        class B(Model): a = FKF(A)              # must follow A
        class C(Model): a, b = FKF(A), FKF(B)   # must follow A and B
        class D(Model): c = FKF(C)              # must follow A and B and C
        class E(Model): e = FKF('self')
        # but excluding this model, which is a child of E
        class Excluded(Model): e = FKF(E)

        inputs = [A, B, C, D, E]

        # property 1: output ordering must not depend upon input order
        repeatable_ordering = None
        for input_ordering in permutations(inputs):
            output_ordering = sort_models_topologically(input_ordering)
            repeatable_ordering = repeatable_ordering or output_ordering
            self.assertEqual(repeatable_ordering, output_ordering)

        # property 2: output ordering must have same models as input
        # (length alone would not catch a duplicate replacing a model)
        self.assertEqual(len(output_ordering), 5)
        self.assertEqual(set(output_ordering), set(inputs))
        self.assertFalse(Excluded in output_ordering)

        # property 3: parents must precede children
        def assert_precedes(X, Y):
            lhs, rhs = map(output_ordering.index, [X, Y])
            self.assertLess(lhs, rhs, '%s should precede %s' % (X.__name__, Y.__name__))

        assert_precedes(A, B)
        assert_precedes(B, C)  # if true, C follows A by transitivity
        assert_precedes(C, D)  # if true, D follows A and B by transitivity

        # property 4: independent model hierarchies must be in name order
        assert_precedes(A, E)
def permutations(xs):
    """Yield every ordering of *xs* as a new list.

    Orderings are produced in the same sequence as
    ``itertools.permutations`` (by input position); an empty input yields
    a single empty list.
    """
    import itertools
    for ordering in itertools.permutations(xs):
        yield list(ordering)


def selections(xs):
    """Yield ``(item, rest)`` for each position of the sequence *xs*.

    ``rest`` is *xs* with that one position removed.  It is built by
    slicing, so it has the same sequence type as *xs*.
    """
    for i, item in enumerate(xs):
        yield (item, xs[:i] + xs[i + 1:])
if test_db.for_update:
    class ForUpdateTestCase(ModelTestCase):
        """Row locking via SELECT ... FOR UPDATE inside a manual transaction."""
        requires = [User]

        def tearDown(self):
            test_db.set_autocommit(True)
            # previously skipped, leaving the base class cleanup undone
            super(ForUpdateTestCase, self).tearDown()

        def test_for_update(self):
            u1 = self.create_user('u1')
            self.create_user('u2')
            self.create_user('u3')

            test_db.set_autocommit(False)

            # select a user for update; the query is lazy, so it has to be
            # iterated for the SELECT ... FOR UPDATE to reach the database
            users = User.select().where(User.username == 'u1').for_update()
            self.assertEqual([u.username for u in users], ['u1'])

            updated = User.update(username='u1_edited').where(User.username == 'u1').execute()
            self.assertEqual(updated, 1)

            # open up a new connection to the database
            new_db = database_class(database_name)
            new_db.connect()
            try:
                # select the username, it will not register as being updated
                res = new_db.execute_sql('select username from users where id = %s;' % u1.id)
                username = res.fetchone()[0]
                self.assertEqual(username, 'u1')

                # committing will cause the lock to be released
                test_db.commit()

                # now we get the update
                res = new_db.execute_sql('select username from users where id = %s;' % u1.id)
                username = res.fetchone()[0]
                self.assertEqual(username, 'u1_edited')
            finally:
                new_db.close()
elif TEST_VERBOSITY > 0:
    print('Skipping "for update" tests')
if test_db.sequences:
    class SequenceTestCase(ModelTestCase):
        """Two models drawing ids from the same database sequence."""
        requires = [SeqModelA, SeqModelB]

        def test_sequence_shared(self):
            # creation order interleaves the two models
            created = [
                SeqModelA.create(num=1),
                SeqModelA.create(num=2),
                SeqModelB.create(other_num=101),
                SeqModelB.create(other_num=102),
                SeqModelA.create(num=3),
            ]
            # ids are consecutive across both models
            for earlier, later in zip(created, created[1:]):
                self.assertEqual(earlier.id, later.id - 1)
elif TEST_VERBOSITY > 0:
    print('Skipping "sequence" tests')
class Job(TestModel):
    """A job that can be queued for later execution.

    JobExecutionRecord references this model through a foreign key that is
    also its primary key.
    """
    # human-readable label; the tests only store it, never query on it
    name = CharField()
class JobExecutionRecord(TestModel):
    """Record of a job having been executed.

    ``job`` is declared with ``primary_key=True``, so the foreign key to Job
    doubles as this table's primary key and at most one record can exist
    per job.
    """
    # the foreign key is also the primary key to enforce the
    # constraint that a job can be executed once and only once
    job = ForeignKeyField(Job, primary_key=True)
    # outcome of the execution; the tests store 'success'
    status = CharField()
class PrimaryForeignKeyTestCase(unittest.TestCase):
    """A foreign key that is also the primary key allows one row per parent."""

    def setUp(self):
        # parent first, so the referenced table exists for the child's FK
        for model in (Job, JobExecutionRecord):
            model.create_table()

    def tearDown(self):
        # drop in the reverse order of creation
        for model in (JobExecutionRecord, Job):
            model.drop_table()

    def test_primary_foreign_key(self):
        def executed_jobs():
            # inner join: only jobs that have an execution record
            return list(Job.select().join(JobExecutionRecord))

        job = Job.create(name='Job One')
        # we have one job, unexecuted, and therefore no executed jobs
        self.assertEqual([], executed_jobs())

        # after execution, we must have one executed job
        JobExecutionRecord.create(job=job, status='success')
        self.assertEqual([job], executed_jobs())

        # a second execution record for the same job must be rejected
        with self.assertRaises(Exception):
            JobExecutionRecord.create(job=job, status='success')
        test_db.rollback()
Jump to Line
Something went wrong with that request. Please try again.