Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support typed 'elasticsearch-py' and add 'py.typed' #295

Merged
merged 1 commit on Oct 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -1,2 +1,3 @@
include LICENSE.txt
include README.md
recursive-include eland py.typed
15 changes: 10 additions & 5 deletions eland/common.py
Expand Up @@ -22,7 +22,7 @@

import numpy as np # type: ignore
import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore
from elasticsearch import Elasticsearch

# Default number of rows displayed (different to pandas where ALL could be displayed)
DEFAULT_NUM_ROWS_DISPLAYED = 60
Expand Down Expand Up @@ -86,7 +86,7 @@ def from_string(order: str) -> "SortOrder":


def elasticsearch_date_to_pandas_date(
value: Union[int, str], date_format: str
value: Union[int, str], date_format: Optional[str]
) -> pd.Timestamp:
"""
Given a specific Elasticsearch format for a date datatype, returns the
Expand Down Expand Up @@ -298,6 +298,7 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
"""Tags the current ES client with a cached '_eland_es_version'
property if one doesn't exist yet for the current Elasticsearch version.
"""
eland_es_version: Tuple[int, int, int]
if not hasattr(es_client, "_eland_es_version"):
version_info = es_client.info()["version"]["number"]
match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_info)
Expand All @@ -306,6 +307,10 @@ def es_version(es_client: Elasticsearch) -> Tuple[int, int, int]:
f"Unable to determine Elasticsearch version. "
f"Received: {version_info}"
)
major, minor, patch = [int(x) for x in match.groups()]
es_client._eland_es_version = (major, minor, patch)
return cast(Tuple[int, int, int], es_client._eland_es_version)
eland_es_version = cast(
Tuple[int, int, int], tuple([int(x) for x in match.groups()])
)
es_client._eland_es_version = eland_es_version # type: ignore
else:
eland_es_version = es_client._eland_es_version # type: ignore
return eland_es_version
6 changes: 3 additions & 3 deletions eland/etl.py
Expand Up @@ -20,8 +20,8 @@
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union

import pandas as pd # type: ignore
from elasticsearch import Elasticsearch # type: ignore
from elasticsearch.helpers import parallel_bulk # type: ignore
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
from pandas.io.parsers import _c_parser_defaults # type: ignore

from eland import DataFrame
Expand Down Expand Up @@ -240,7 +240,7 @@ def action_generator(
pd_df, es_dropna, use_pandas_index_for_es_ids, es_dest_index
),
thread_count=thread_count,
chunk_size=chunksize / thread_count,
chunk_size=int(chunksize / thread_count),
),
maxlen=0,
)
Expand Down
10 changes: 6 additions & 4 deletions eland/ml/ml_model.py
Expand Up @@ -18,7 +18,7 @@
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast

import elasticsearch # type: ignore
import elasticsearch
import numpy as np # type: ignore

from eland.common import ensure_es_client, es_version
Expand Down Expand Up @@ -447,11 +447,13 @@ def _trained_model_config(self) -> Dict[str, Any]:
# In Elasticsearch 7.7 and earlier you can't get
# target type without pulling the model definition
# so we check the version first.
kwargs = {}
if es_version(self._client) < (7, 8):
kwargs["include_model_definition"] = True
resp = self._client.ml.get_trained_models(
model_id=self._model_id, include_model_definition=True
)
else:
resp = self._client.ml.get_trained_models(model_id=self._model_id)

resp = self._client.ml.get_trained_models(model_id=self._model_id, **kwargs)
if resp["count"] > 1:
raise ValueError(f"Model ID {self._model_id!r} wasn't unambiguous")
elif resp["count"] == 0:
Expand Down
45 changes: 27 additions & 18 deletions eland/ndframe.py
Expand Up @@ -17,13 +17,15 @@

import sys
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import pandas as pd
import pandas as pd # type: ignore

from eland.query_compiler import QueryCompiler

if TYPE_CHECKING:
from elasticsearch import Elasticsearch

from eland.index import Index

"""
Expand Down Expand Up @@ -55,12 +57,14 @@
class NDFrame(ABC):
def __init__(
self,
es_client=None,
es_index_pattern=None,
columns=None,
es_index_field=None,
_query_compiler=None,
):
es_client: Optional[
Union[str, List[str], Tuple[str, ...], "Elasticsearch"]
] = None,
es_index_pattern: Optional[str] = None,
columns: Optional[List[str]] = None,
es_index_field: Optional[str] = None,
_query_compiler: Optional[QueryCompiler] = None,
) -> None:
"""
pandas.DataFrame/Series like API that proxies into Elasticsearch index(es).

Expand Down Expand Up @@ -134,7 +138,7 @@ def dtypes(self) -> pd.Series:
return self._query_compiler.dtypes

@property
def es_dtypes(self):
def es_dtypes(self) -> pd.Series:
"""
Return the Elasticsearch dtypes in the index

Expand All @@ -155,7 +159,7 @@ def es_dtypes(self):
"""
return self._query_compiler.es_dtypes

def _build_repr(self, num_rows) -> pd.DataFrame:
def _build_repr(self, num_rows: int) -> pd.DataFrame:
# self could be Series or DataFrame
if len(self.index) <= num_rows:
return self.to_pandas()
Expand Down Expand Up @@ -639,20 +643,25 @@ def describe(self) -> pd.DataFrame:
return self._query_compiler.describe()

@abstractmethod
def to_pandas(self, show_progress=False):
pass
def to_pandas(self, show_progress: bool = False) -> pd.DataFrame:
raise NotImplementedError

@abstractmethod
def head(self, n=5):
pass
def head(self, n: int = 5) -> "NDFrame":
raise NotImplementedError

@abstractmethod
def tail(self, n=5):
pass
def tail(self, n: int = 5) -> "NDFrame":
raise NotImplementedError

@abstractmethod
def sample(self, n=None, frac=None, random_state=None):
pass
def sample(
self,
n: Optional[int] = None,
frac: Optional[float] = None,
random_state: Optional[int] = None,
) -> "NDFrame":
raise NotImplementedError

@property
def shape(self) -> Tuple[int, ...]:
Expand Down
Empty file added eland/py.typed
Empty file.
8 changes: 4 additions & 4 deletions eland/query_compiler.py
Expand Up @@ -94,11 +94,11 @@ def __init__(
self._operations = Operations()

@property
def index(self):
def index(self) -> Index:
return self._index

@property
def columns(self):
def columns(self) -> pd.Index:
columns = self._mappings.display_names

return pd.Index(columns)
Expand All @@ -120,11 +120,11 @@ def add_scripted_field(self, scripted_field_name, display_name, pd_dtype):
return result

@property
def dtypes(self):
def dtypes(self) -> pd.Series:
return self._mappings.dtypes()

@property
def es_dtypes(self):
def es_dtypes(self) -> pd.Series:
return self._mappings.es_dtypes()

# END Index, columns, and dtypes objects
Expand Down
1 change: 1 addition & 0 deletions noxfile.py
Expand Up @@ -68,6 +68,7 @@ def format(session):
@nox.session(reuse_venv=True)
def lint(session):
session.install("black", "flake8", "mypy", "isort")
session.install("--pre", "elasticsearch")
session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)
session.run("black", "--check", "--target-version=py36", *SOURCE_FILES)
session.run("isort", "--check", *SOURCE_FILES)
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Expand Up @@ -72,6 +72,9 @@
packages=find_packages(include=["eland", "eland.*"]),
install_requires=["elasticsearch>=7.7", "pandas>=1", "matplotlib", "numpy"],
python_requires=">=3.6",
package_data={"eland": ["py.typed"]},
include_package_data=True,
zip_safe=False,
extras_require={
"xgboost": ["xgboost>=0.90,<2"],
"scikit-learn": ["scikit-learn>=0.22.1,<1"],
Expand Down