Skip to content

Commit

Permalink
Merge 06b4601 into c46a5d9
Browse files Browse the repository at this point in the history
  • Loading branch information
nickpetrovic committed Jun 14, 2021
2 parents c46a5d9 + 06b4601 commit 8ac30c6
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pypika/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ class Arithmetic(Enum):
sub = "-"
mul = "*"
div = "/"
lshift = "<<"
rshift = ">>"


class Comparator(Enum):
Expand Down
18 changes: 18 additions & 0 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def bin_regex(self, pattern: str) -> "BasicCriterion":
def negate(self) -> "Not":
return Not(self)

def lshift(self, other: Any) -> "ArithmeticExpression":
return self << other

def rshift(self, other: Any) -> "ArithmeticExpression":
return self >> other

def __invert__(self) -> "Not":
return Not(self)

Expand Down Expand Up @@ -237,6 +243,18 @@ def __rmul__(self, other: Any) -> "ArithmeticExpression":
def __rtruediv__(self, other: Any) -> "ArithmeticExpression":
return ArithmeticExpression(Arithmetic.div, self.wrap_constant(other), self)

def __lshift__(self, other: Any) -> "ArithmeticExpression":
return ArithmeticExpression(Arithmetic.lshift, self, self.wrap_constant(other))

def __rshift__(self, other: Any) -> "ArithmeticExpression":
return ArithmeticExpression(Arithmetic.rshift, self, self.wrap_constant(other))

def __rlshift__(self, other: Any) -> "ArithmeticExpression":
return ArithmeticExpression(Arithmetic.lshift, self.wrap_constant(other), self)

def __rrshift__(self, other: Any) -> "ArithmeticExpression":
return ArithmeticExpression(Arithmetic.rshift, self.wrap_constant(other), self)

def __eq__(self, other: Any) -> "BasicCriterion":
return BasicCriterion(Equality.eq, self, self.wrap_constant(other))

Expand Down
9 changes: 9 additions & 0 deletions pypika/tests/dialects/test_postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,15 @@ def test_json_has_any_keys(self):

self.assertEqual("SELECT * FROM \"abc\" WHERE \"json\"?|ARRAY['dates','imported']", str(q))

def test_subnet_contains_inet(self):
q = (
PostgreSQLQuery.from_(self.table_abc)
.select(self.table_abc.a.lshift(2))
.where(self.table_abc.cidr >> "1.1.1.1")
)

self.assertEqual("SELECT \"a\"<<2 FROM \"abc\" WHERE \"cidr\">>'1.1.1.1'", str(q))


class DistinctOnTests(unittest.TestCase):
table_abc = Table("abc")
Expand Down
17 changes: 17 additions & 0 deletions pypika/tests/test_criterions.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,23 @@ def test__in_unicode(self):
self.assertEqual("\"isin\".\"foo\" IN ('a','b')", str(c2))


class ArithmeticExpressionTests(unittest.TestCase):

def test__lshift(self):
c1 = Field("foo").lshift("1")
c2 = Field("foo").lshift("2")

self.assertEqual("\"foo\"<<'1'", str(c1))
self.assertEqual("\"foo\"<<'2'", str(c2))

def test__rshift(self):
c1 = Field("foo").rshift("1")
c2 = Field("foo").rshift("2")

self.assertEqual("\"foo\">>'1'", str(c1))
self.assertEqual("\"foo\">>'2'", str(c2))


class NotInTests(unittest.TestCase):
t = Table("abc", alias="notin")

Expand Down
42 changes: 42 additions & 0 deletions pypika/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,48 @@ def test__division__right(self):
self.assertEqual('SELECT 1/"a" FROM "abc"', str(q1))
self.assertEqual('SELECT 1/"a" FROM "abc"', str(q2))

def test__leftshift__fields(self):
q1 = Q.from_("abc").select(F("a") << F("b"))
q2 = Q.from_(self.t).select(self.t.a << self.t.b)

self.assertEqual('SELECT "a"<<"b" FROM "abc"', str(q1))
self.assertEqual('SELECT "a"<<"b" FROM "abc"', str(q2))

def test__leftshift__number(self):
q1 = Q.from_("abc").select(F('a') << 2)
q2 = Q.from_(self.t).select(self.t.a << 2)

self.assertEqual('SELECT "a"<<2 FROM "abc"', str(q1))
self.assertEqual('SELECT "a"<<2 FROM "abc"', str(q2))

def test__leftshift__right(self):
q1 = Q.from_("abc").select(1 << F("a"))
q2 = Q.from_(self.t).select(1 << self.t.a)

self.assertEqual('SELECT 1<<"a" FROM "abc"', str(q1))
self.assertEqual('SELECT 1<<"a" FROM "abc"', str(q2))

def test__rightshift__fields(self):
q1 = Q.from_("abc").select(F("a") >> F("b"))
q2 = Q.from_(self.t).select(self.t.a >> self.t.b)

self.assertEqual('SELECT "a">>"b" FROM "abc"', str(q1))
self.assertEqual('SELECT "a">>"b" FROM "abc"', str(q2))

def test__rightshift__number(self):
q1 = Q.from_("abc").select(F('a') >> 2)
q2 = Q.from_(self.t).select(self.t.a >> 2)

self.assertEqual('SELECT "a">>2 FROM "abc"', str(q1))
self.assertEqual('SELECT "a">>2 FROM "abc"', str(q2))

def test__rightshift__right(self):
q1 = Q.from_("abc").select(1 >> F("a"))
q2 = Q.from_(self.t).select(1 >> self.t.a)

self.assertEqual('SELECT 1>>"a" FROM "abc"', str(q1))
self.assertEqual('SELECT 1>>"a" FROM "abc"', str(q2))

def test__complex_op(self):
q1 = Q.from_("abc").select(2 + 1 / F("a") - 5)
q2 = Q.from_(self.t).select(2 + 1 / self.t.a - 5)
Expand Down

0 comments on commit 8ac30c6

Please sign in to comment.