Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: share logic for Criterion comparisions #805

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ default_stages: [commit]
fail_fast: false

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- repo: git@github.com:pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
Expand All @@ -19,7 +19,7 @@ repos:
- id: check-yaml
- id: debug-statements

- repo: https://github.com/psf/black
- repo: git@github.com:psf/black
rev: 23.9.1
hooks:
- id: black
33 changes: 15 additions & 18 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ def wrap_constant(

"""

if isinstance(val, Node):
return val
if val is None:
return NullValue()
if isinstance(val, Node):
return val
if isinstance(val, list):
return Array(*val)
if isinstance(val, tuple):
Expand All @@ -94,10 +94,10 @@ def wrap_json(
) -> Union["Term", "QueryBuilder", "Interval", "NullValue", "ValueWrapper", "JSON"]:
from .queries import QueryBuilder

if isinstance(val, (Term, QueryBuilder, Interval)):
return val
if val is None:
return NullValue()
if isinstance(val, (Term, QueryBuilder, Interval)):
return val
if isinstance(val, (str, int, bool)):
wrapper_cls = wrapper_cls or ValueWrapper
return wrapper_cls(val)
Expand Down Expand Up @@ -316,12 +316,9 @@ def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class NamedParameter(Parameter):
class NamedParameter(NumericParameter):
"""Named style, e.g. ...WHERE name=:name"""

def get_sql(self, **kwargs: Any) -> str:
return ":{placeholder}".format(placeholder=self.placeholder)


class FormatParameter(Parameter):
"""ANSI C printf format codes, e.g. ...WHERE name=%s"""
Expand Down Expand Up @@ -365,6 +362,8 @@ def get_value_sql(self, **kwargs: Any) -> str:

@classmethod
def get_formatted_value(cls, value: Any, **kwargs):
if value is None:
return "null"
quote_char = kwargs.get("secondary_quote_char") or ""

# FIXME escape values
Expand All @@ -381,8 +380,6 @@ def get_formatted_value(cls, value: Any, **kwargs):
return str.lower(str(value))
if isinstance(value, uuid.UUID):
return cls.get_formatted_value(str(value), **kwargs)
if value is None:
return "null"
return str(value)

def get_sql(self, quote_char: Optional[str] = None, secondary_quote_char: str = "'", **kwargs: Any) -> str:
Expand Down Expand Up @@ -485,20 +482,19 @@ def __init__(self, alias: Optional[str] = None) -> None:


class Criterion(Term):
def __and__(self, other: Any) -> "ComplexCriterion":
def _compare(self, comparator: Comparator, other: Any) -> "ComplexCriterion":
if isinstance(other, EmptyCriterion):
return self
return ComplexCriterion(Boolean.and_, self, other)
return ComplexCriterion(comparator, self, other)

def __and__(self, other: Any) -> "ComplexCriterion":
return self._compare(Boolean.and_, other)

def __or__(self, other: Any) -> "ComplexCriterion":
if isinstance(other, EmptyCriterion):
return self
return ComplexCriterion(Boolean.or_, self, other)
return self._compare(Boolean.or_, other)

def __xor__(self, other: Any) -> "ComplexCriterion":
if isinstance(other, EmptyCriterion):
return self
return ComplexCriterion(Boolean.xor_, self, other)
return self._compare(Boolean.xor_, other)

@staticmethod
def any(terms: Iterable[Term] = ()) -> "EmptyCriterion":
Expand Down Expand Up @@ -551,6 +547,7 @@ def __init__(
if isinstance(table, str):
# avoid circular import at load time
from pypika.queries import Table

table = Table(table)
self.table = table

Expand Down