Skip to content

Commit

Permalink
Add Type hints to query.py (#1821)
Browse files Browse the repository at this point in the history
* refactor: add type hints to query.py + type_checking in CI

chore: fix linting

* fix: fix typing for older versions of python

* refactor: add typing to query tests

* chore: add type_check to CI

* fix: fix typing for older python versions

* fix: fix python version for ci

---------

Co-authored-by: Miguel Grinberg <miguel.grinberg@gmail.com>
  • Loading branch information
Caiofcas and miguelgrinberg committed May 16, 2024
1 parent b5435a8 commit e68585b
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 95 deletions.
15 changes: 15 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ jobs:
- name: Lint the code
run: nox -s lint

type_check:
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
run: |
python3 -m pip install nox
- name: Lint the code
run: nox -s type_check

docs:
runs-on: ubuntu-latest
steps:
Expand Down
6 changes: 4 additions & 2 deletions elasticsearch_dsl/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
# under the License.

import collections.abc
from typing import Dict

from .utils import DslBase


def SF(name_or_sf, **params):
# Incomplete annotation to not break query.py tests
def SF(name_or_sf, **params) -> "ScoreFunction":
# {"script_score": {"script": "_score"}, "filter": {}}
if isinstance(name_or_sf, collections.abc.Mapping):
if params:
Expand Down Expand Up @@ -86,7 +88,7 @@ class ScriptScore(ScoreFunction):
class BoostFactor(ScoreFunction):
name = "boost_factor"

def to_dict(self):
def to_dict(self) -> Dict[str, int]:
d = super().to_dict()
if "value" in d[self.name]:
d[self.name] = d[self.name].pop("value")
Expand Down
97 changes: 74 additions & 23 deletions elasticsearch_dsl/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,73 @@
# under the License.

import collections.abc
from copy import deepcopy
from itertools import chain
from typing import (
Any,
Callable,
ClassVar,
List,
Mapping,
MutableMapping,
Optional,
Protocol,
TypeVar,
Union,
cast,
overload,
)

# 'SF' looks unused but the test suite assumes it's available
# from this module so others are liable to do so as well.
from .function import SF # noqa: F401
from .function import ScoreFunction
from .utils import DslBase

_T = TypeVar("_T")
_M = TypeVar("_M", bound=Mapping[str, Any])

def Q(name_or_query="match_all", **params):

class QProxiedProtocol(Protocol[_T]):
_proxied: _T


@overload
def Q(name_or_query: MutableMapping[str, _M]) -> "Query": ...


@overload
def Q(name_or_query: "Query") -> "Query": ...


@overload
def Q(name_or_query: QProxiedProtocol[_T]) -> _T: ...


@overload
def Q(name_or_query: str = "match_all", **params: Any) -> "Query": ...


def Q(
name_or_query: Union[
str,
"Query",
QProxiedProtocol[_T],
MutableMapping[str, _M],
] = "match_all",
**params: Any,
) -> Union["Query", _T]:
# {"match": {"title": "python"}}
if isinstance(name_or_query, collections.abc.Mapping):
if isinstance(name_or_query, collections.abc.MutableMapping):
if params:
raise ValueError("Q() cannot accept parameters when passing in a dict.")
if len(name_or_query) != 1:
raise ValueError(
'Q() can only accept dict with a single query ({"match": {...}}). '
"Instead it got (%r)" % name_or_query
)
name, params = name_or_query.copy().popitem()
return Query.get_dsl_class(name)(_expand__to_dot=False, **params)
name, q_params = deepcopy(name_or_query).popitem()
return Query.get_dsl_class(name)(_expand__to_dot=False, **q_params)

# MatchAll()
if isinstance(name_or_query, Query):
Expand All @@ -48,7 +94,7 @@ def Q(name_or_query="match_all", **params):

# s.query = Q('filtered', query=s.query)
if hasattr(name_or_query, "_proxied"):
return name_or_query._proxied
return cast(QProxiedProtocol[_T], name_or_query)._proxied

# "match", title="python"
return Query.get_dsl_class(name_or_query)(**params)
Expand All @@ -57,26 +103,31 @@ def Q(name_or_query="match_all", **params):
class Query(DslBase):
_type_name = "query"
_type_shortcut = staticmethod(Q)
name = None
name: ClassVar[Optional[str]] = None

# Add type annotations for methods not defined in every subclass
__ror__: ClassVar[Callable[["Query", "Query"], "Query"]]
__radd__: ClassVar[Callable[["Query", "Query"], "Query"]]
__rand__: ClassVar[Callable[["Query", "Query"], "Query"]]

def __add__(self, other):
def __add__(self, other: "Query") -> "Query":
# make sure we give queries that know how to combine themselves
# preference
if hasattr(other, "__radd__"):
return other.__radd__(self)
return Bool(must=[self, other])

def __invert__(self):
def __invert__(self) -> "Query":
return Bool(must_not=[self])

def __or__(self, other):
def __or__(self, other: "Query") -> "Query":
# make sure we give queries that know how to combine themselves
# preference
if hasattr(other, "__ror__"):
return other.__ror__(self)
return Bool(should=[self, other])

def __and__(self, other):
def __and__(self, other: "Query") -> "Query":
# make sure we give queries that know how to combine themselves
# preference
if hasattr(other, "__rand__"):
Expand All @@ -87,17 +138,17 @@ def __and__(self, other):
class MatchAll(Query):
name = "match_all"

def __add__(self, other):
def __add__(self, other: "Query") -> "Query":
return other._clone()

__and__ = __rand__ = __radd__ = __add__

def __or__(self, other):
def __or__(self, other: "Query") -> "MatchAll":
return self

__ror__ = __or__

def __invert__(self):
def __invert__(self) -> "MatchNone":
return MatchNone()


Expand All @@ -107,17 +158,17 @@ def __invert__(self):
class MatchNone(Query):
name = "match_none"

def __add__(self, other):
def __add__(self, other: "Query") -> "MatchNone":
return self

__and__ = __rand__ = __radd__ = __add__

def __or__(self, other):
def __or__(self, other: "Query") -> "Query":
return other._clone()

__ror__ = __or__

def __invert__(self):
def __invert__(self) -> MatchAll:
return MatchAll()


Expand All @@ -130,7 +181,7 @@ class Bool(Query):
"filter": {"type": "query", "multi": True},
}

def __add__(self, other):
def __add__(self, other: Query) -> "Bool":
q = self._clone()
if isinstance(other, Bool):
q.must += other.must
Expand All @@ -143,7 +194,7 @@ def __add__(self, other):

__radd__ = __add__

def __or__(self, other):
def __or__(self, other: Query) -> Query:
for q in (self, other):
if isinstance(q, Bool) and not any(
(q.must, q.must_not, q.filter, getattr(q, "minimum_should_match", None))
Expand All @@ -168,20 +219,20 @@ def __or__(self, other):
__ror__ = __or__

@property
def _min_should_match(self):
def _min_should_match(self) -> int:
return getattr(
self,
"minimum_should_match",
0 if not self.should or (self.must or self.filter) else 1,
)

def __invert__(self):
def __invert__(self) -> Query:
# Because an empty Bool query is treated like
# MatchAll the inverse should be MatchNone
if not any(chain(self.must, self.filter, self.should, self.must_not)):
return MatchNone()

negations = []
negations: List[Query] = []
for q in chain(self.must, self.filter):
negations.append(~q)

Expand All @@ -195,7 +246,7 @@ def __invert__(self):
return negations[0]
return Bool(should=negations)

def __and__(self, other):
def __and__(self, other: Query) -> Query:
q = self._clone()
if isinstance(other, Bool):
q.must += other.must
Expand Down Expand Up @@ -247,7 +298,7 @@ class FunctionScore(Query):
"functions": {"type": "score_function", "multi": True},
}

def __init__(self, **kwargs):
def __init__(self, **kwargs: Any):
if "functions" in kwargs:
pass
else:
Expand Down
14 changes: 10 additions & 4 deletions elasticsearch_dsl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@

import collections.abc
from copy import copy
from typing import Any, Dict, Optional, Type

from typing_extensions import Self

from .exceptions import UnknownDslObject, ValidationException

Expand Down Expand Up @@ -251,7 +254,9 @@ class DslBase(metaclass=DslMeta):
_param_defs = {}

@classmethod
def get_dsl_class(cls, name, default=None):
def get_dsl_class(
cls: Type[Self], name: str, default: Optional[str] = None
) -> Type[Self]:
try:
return cls._classes[name]
except KeyError:
Expand All @@ -261,7 +266,7 @@ def get_dsl_class(cls, name, default=None):
f"DSL class `{name}` does not exist in {cls._type_name}."
)

def __init__(self, _expand__to_dot=None, **params):
def __init__(self, _expand__to_dot: Optional[bool] = None, **params: Any) -> None:
if _expand__to_dot is None:
_expand__to_dot = EXPAND__TO_DOT
self._params = {}
Expand Down Expand Up @@ -351,7 +356,8 @@ def __getattr__(self, name):
return AttrDict(value)
return value

def to_dict(self):
# TODO: This type annotation can probably be made tighter
def to_dict(self) -> Dict[str, Dict[str, Any]]:
"""
Serialize the DSL object to plain dict
"""
Expand Down Expand Up @@ -390,7 +396,7 @@ def to_dict(self):
d[pname] = value
return {self.name: d}

def _clone(self):
def _clone(self) -> Self:
c = self.__class__()
for attr in self._params:
c._params[attr] = copy(self._params[attr])
Expand Down
3 changes: 3 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[mypy-elasticsearch_dsl.query]
# Allow reexport of SF for tests
implicit_reexport = True
34 changes: 33 additions & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.

import subprocess

import nox

SOURCE_FILES = (
Expand All @@ -27,6 +29,11 @@
"utils/",
)

TYPED_FILES = (
"elasticsearch_dsl/query.py",
"tests/test_query.py",
)


@nox.session(
python=[
Expand Down Expand Up @@ -72,10 +79,35 @@ def lint(session):
session.run("black", "--check", "--target-version=py38", *SOURCE_FILES)
session.run("isort", "--check", *SOURCE_FILES)
session.run("python", "utils/run-unasync.py", "--check")
session.run("flake8", "--ignore=E501,E741,W503", *SOURCE_FILES)
session.run("flake8", "--ignore=E501,E741,W503,E704", *SOURCE_FILES)
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)


@nox.session(python="3.8")
def type_check(session):
session.install("mypy", ".[develop]")
errors = []
popen = subprocess.Popen(
"mypy --strict elasticsearch_dsl tests",
env=session.env,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)

mypy_output = ""
while popen.poll() is None:
mypy_output += popen.stdout.read(8192).decode()
mypy_output += popen.stdout.read().decode()

for line in mypy_output.split("\n"):
filepath = line.partition(":")[0]
if filepath in TYPED_FILES:
errors.append(line)
if errors:
session.error("\n" + "\n".join(sorted(set(errors))))


@nox.session()
def docs(session):
session.install(".[develop]")
Expand Down

0 comments on commit e68585b

Please sign in to comment.