Skip to content

Commit

Permalink
Merge ce5856d into cb8ae12
Browse files Browse the repository at this point in the history
  • Loading branch information
squareapartments committed Jul 16, 2021
2 parents cb8ae12 + ce5856d commit 5303972
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 4 deletions.
56 changes: 55 additions & 1 deletion pypika/dialects.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from copy import copy
from typing import Any, Optional, Union
from typing import Any, Optional, Union, Tuple as TypedTuple

from pypika.enums import Dialects
from pypika.queries import (
Expand Down Expand Up @@ -84,12 +84,25 @@ def __init__(self, **kwargs: Any) -> None:
self._ignore_duplicates = False
self._modifiers = []

self._for_update_nowait = False
self._for_update_skip_locked = False
self._for_update_of = set()

def __copy__(self) -> "MySQLQueryBuilder":
newone = super().__copy__()
newone._duplicate_updates = copy(self._duplicate_updates)
newone._ignore_duplicates = copy(self._ignore_duplicates)
return newone

@builder
def for_update(
self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = ()
) -> "QueryBuilder":
self._for_update = True
self._for_update_skip_locked = skip_locked
self._for_update_nowait = nowait
self._for_update_of = set(of)

@builder
def on_duplicate_key_update(self, field: Union[Field, str], value: Any) -> "MySQLQueryBuilder":
if self._ignore_duplicates:
Expand All @@ -115,6 +128,20 @@ def get_sql(self, **kwargs: Any) -> str:
querystring += self._on_duplicate_key_ignore_sql()
return querystring

def _for_update_sql(self, **kwargs) -> str:
if self._for_update:
for_update = ' FOR UPDATE'
if self._for_update_of:
for_update += f' OF {", ".join([Table(item).get_sql(**kwargs) for item in self._for_update_of])}'
if self._for_update_nowait:
for_update += ' NOWAIT'
elif self._for_update_skip_locked:
for_update += ' SKIP LOCKED'
else:
for_update = ''

return for_update

def _on_duplicate_key_update_sql(self, **kwargs: Any) -> str:
return " ON DUPLICATE KEY UPDATE {updates}".format(
updates=",".join(
Expand Down Expand Up @@ -356,6 +383,10 @@ def __init__(self, **kwargs: Any) -> None:

self._distinct_on = []

self._for_update_nowait = False
self._for_update_skip_locked = False
self._for_update_of = set()

def __copy__(self) -> "PostgreSQLQueryBuilder":
newone = super().__copy__()
newone._returns = copy(self._returns)
Expand All @@ -370,6 +401,15 @@ def distinct_on(self, *fields: Union[str, Term]) -> "PostgreSQLQueryBuilder":
elif isinstance(field, Term):
self._distinct_on.append(field)

@builder
def for_update(
self, nowait: bool = False, skip_locked: bool = False, of: TypedTuple[str, ...] = ()
) -> "QueryBuilder":
self._for_update = True
self._for_update_skip_locked = skip_locked
self._for_update_nowait = nowait
self._for_update_of = set(of)

@builder
def on_conflict(self, *target_fields: Union[str, Term]) -> "PostgreSQLQueryBuilder":
if not self._insert_table:
Expand Down Expand Up @@ -466,6 +506,20 @@ def _on_conflict_sql(self, **kwargs: Any) -> str:

return conflict_query

def _for_update_sql(self, **kwargs) -> str:
if self._for_update:
for_update = ' FOR UPDATE'
if self._for_update_of:
for_update += f' OF {", ".join([Table(item).get_sql(**kwargs) for item in self._for_update_of])}'
if self._for_update_nowait:
for_update += ' NOWAIT'
elif self._for_update_skip_locked:
for_update += ' SKIP LOCKED'
else:
for_update = ''

return for_update

def _on_conflict_action_sql(self, **kwargs: Any) -> str:
if self._on_conflict_do_nothing:
return " DO NOTHING"
Expand Down
7 changes: 4 additions & 3 deletions pypika/queries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from copy import copy
from functools import reduce
from typing import Any, List, Optional, Sequence, Tuple as TypedTuple, Type, Union
from typing import Any, List, Optional, Sequence, Tuple as TypedTuple, Type, Union, Set

from pypika.enums import Dialects, JoinType, ReferenceOption, SetOperation
from pypika.terms import (
Expand Down Expand Up @@ -704,6 +704,7 @@ def __init__(
self._values = []
self._distinct = False
self._ignore = False

self._for_update = False

self._wheres = None
Expand Down Expand Up @@ -1330,7 +1331,7 @@ def get_sql(self, with_alias: bool = False, subquery: bool = False, **kwargs: An
querystring = self._apply_pagination(querystring)

if self._for_update:
querystring += self._for_update_sql()
querystring += self._for_update_sql(**kwargs)

if subquery:
querystring = "({query})".format(query=querystring)
Expand Down Expand Up @@ -1366,7 +1367,7 @@ def _distinct_sql(self, **kwargs: Any) -> str:

return distinct

def _for_update_sql(self) -> str:
def _for_update_sql(self, **kwargs) -> str:
if self._for_update:
for_update = ' FOR UPDATE'
else:
Expand Down
153 changes: 153 additions & 0 deletions pypika/tests/test_selects.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ class MyEnum(Enum):

class WhereTests(unittest.TestCase):
t = Table("abc")
t2 = Table("cba")

def test_where_enum(self):
q1 = Query.from_(self.t).select("*").where(self.t.foo == MyEnum.STR)
Expand Down Expand Up @@ -378,6 +379,158 @@ def test_where_field_equals_for_update(self):
q = Query.from_(self.t).select("*").where(self.t.foo == self.t.bar).for_update()
self.assertEqual('SELECT * FROM "abc" WHERE "foo"="bar" FOR UPDATE', str(q))

def test_where_field_equals_for_update_nowait(self):
for query_cls in [
MySQLQuery,
PostgreSQLQuery,
]:
quote_char = query_cls._builder().QUOTE_CHAR if isinstance(query_cls._builder().QUOTE_CHAR, str) else '"'
q = query_cls.from_(self.t).select("*").where(self.t.foo == self.t.bar).for_update(nowait=True)
self.assertEqual(
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'WHERE {quote_char}foo{quote_char}={quote_char}bar{quote_char} '
'FOR UPDATE NOWAIT'.format(
quote_char=quote_char,
),
str(q),
)

def test_where_field_equals_for_update_skip_locked(self):
for query_cls in [
MySQLQuery,
PostgreSQLQuery,
]:
quote_char = query_cls._builder().QUOTE_CHAR if isinstance(query_cls._builder().QUOTE_CHAR, str) else '"'
q = query_cls.from_(self.t).select("*").where(self.t.foo == self.t.bar).for_update(skip_locked=True)
self.assertEqual(
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'WHERE {quote_char}foo{quote_char}={quote_char}bar{quote_char} '
'FOR UPDATE SKIP LOCKED'.format(
quote_char=quote_char,
),
str(q),
)

def test_where_field_equals_for_update_of(self):
for query_cls in [
MySQLQuery,
PostgreSQLQuery,
]:
quote_char = query_cls._builder().QUOTE_CHAR if isinstance(query_cls._builder().QUOTE_CHAR, str) else '"'
q = query_cls.from_(self.t).select("*").where(self.t.foo == self.t.bar).for_update(of=("abc",))
self.assertEqual(
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'WHERE {quote_char}foo{quote_char}={quote_char}bar{quote_char} '
'FOR UPDATE OF {quote_char}abc{quote_char}'.format(
quote_char=quote_char,
),
str(q),
)

def test_where_field_equals_for_update_of_multiple_tables(self):
for query_cls in [
MySQLQuery,
PostgreSQLQuery,
]:
q = (
query_cls.from_(self.t)
.join(self.t2)
.on(self.t.id == self.t2.abc_id)
.select("*")
.where(self.t.foo == self.t.bar)
.for_update(of=("abc", "cba"))
)
quote_char = query_cls._builder().QUOTE_CHAR if isinstance(query_cls._builder().QUOTE_CHAR, str) else '"'
self.assertIn(
str(q),
[
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'JOIN {quote_char}cba{quote_char} '
'ON {quote_char}abc{quote_char}.{quote_char}id{quote_char}='
'{quote_char}cba{quote_char}.{quote_char}abc_id{quote_char} '
'WHERE {quote_char}abc{quote_char}.{quote_char}foo{quote_char}='
'{quote_char}abc{quote_char}.{quote_char}bar{quote_char} '
'FOR UPDATE OF {quote_char}cba{quote_char}, {quote_char}abc{quote_char}'.format(
quote_char=quote_char,
),
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'JOIN {quote_char}cba{quote_char} '
'ON {quote_char}abc{quote_char}.{quote_char}id{quote_char}='
'{quote_char}cba{quote_char}.{quote_char}abc_id{quote_char} '
'WHERE {quote_char}abc{quote_char}.{quote_char}foo{quote_char}='
'{quote_char}abc{quote_char}.{quote_char}bar{quote_char} '
'FOR UPDATE OF {quote_char}abc{quote_char}, {quote_char}cba{quote_char}'.format(
quote_char=quote_char,
),
],
)

def test_where_field_equals_for_update_of_nowait(self):
for query_cls in [
MySQLQuery,
PostgreSQLQuery,
]:
q = query_cls.from_(self.t).select("*").where(self.t.foo == self.t.bar).for_update(of=("abc",), nowait=True)
quote_char = query_cls._builder().QUOTE_CHAR if isinstance(query_cls._builder().QUOTE_CHAR, str) else '"'
self.assertEqual(
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'WHERE {quote_char}foo{quote_char}={quote_char}bar{quote_char} '
'FOR UPDATE OF {quote_char}abc{quote_char} NOWAIT'.format(
quote_char=quote_char,
),
str(q),
)

def test_where_field_equals_for_update_of_skip_locked(self):
for query_cls in [
MySQLQuery,
PostgreSQLQuery,
]:
q = (
query_cls.from_(self.t)
.select("*")
.where(self.t.foo == self.t.bar)
.for_update(of=("abc",), skip_locked=True)
)
quote_char = query_cls._builder().QUOTE_CHAR if isinstance(query_cls._builder().QUOTE_CHAR, str) else '"'
self.assertEqual(
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'WHERE {quote_char}foo{quote_char}={quote_char}bar{quote_char} '
'FOR UPDATE OF {quote_char}abc{quote_char} SKIP LOCKED'.format(
quote_char=quote_char,
),
str(q),
)

def test_where_field_equals_for_update_skip_locked_and_of(self):
for query_cls in [
MySQLQuery,
PostgreSQLQuery,
]:
q = (
query_cls.from_(self.t)
.select("*")
.where(self.t.foo == self.t.bar)
.for_update(nowait=False, skip_locked=True, of=("abc",))
)
quote_char = query_cls._builder().QUOTE_CHAR if isinstance(query_cls._builder().QUOTE_CHAR, str) else '"'
self.assertEqual(
'SELECT * '
'FROM {quote_char}abc{quote_char} '
'WHERE {quote_char}foo{quote_char}={quote_char}bar{quote_char} '
'FOR UPDATE OF {quote_char}abc{quote_char} SKIP LOCKED'.format(
quote_char=quote_char,
),
str(q),
)

def test_where_field_equals_where(self):
q = Query.from_(self.t).select("*").where(self.t.foo == 1).where(self.t.bar == self.t.baz)

Expand Down

0 comments on commit 5303972

Please sign in to comment.