Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample method to eland.DataFrame #196

Merged
merged 8 commits into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
34 changes: 34 additions & 0 deletions eland/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,40 @@ def tail(self, n: int = 5) -> "DataFrame":
"""
return DataFrame(query_compiler=self._query_compiler.tail(n))

def sample(self, n=None, frac=None):
"""
Return n randomly sample rows or the specify fraction of rows

Parameters
----------
n : int, optional
Number of documents from index to return. Cannot be used with `frac`.
Default = 1 if `frac` = None.
frac : float, optional
Fraction of axis items to return. Cannot be used with `n`.

Returns
-------
eland.DataFrame:
eland DataFrame filtered containing n rows randomly sampled

See Also
--------
:pandas_api_docs:`pandas.DataFrame.sample`
"""

if frac is not None and frac > 1:
sethmlarson marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(
"Replace has to be set to `True` when "
"upsampling the population `frac` > 1."
)
elif n is not None and frac is None and n % 1 != 0:
raise ValueError("Only integers accepted as `n` values")
elif n is not None and frac is not None:
sethmlarson marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError("Please enter a value for `frac` OR `n`, not both")

return DataFrame(query_compiler=self._query_compiler.sample(n=n, frac=frac))

def drop(
self,
labels=None,
Expand Down
5 changes: 5 additions & 0 deletions eland/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from eland.tasks import (
HeadTask,
TailTask,
SampleTask,
BooleanFilterTask,
ArithmeticOpFieldsTask,
QueryTermsTask,
Expand Down Expand Up @@ -91,6 +92,10 @@ def tail(self, index, n):
task = TailTask(index.sort_field, n)
self._tasks.append(task)

def sample(self, n):
task = SampleTask(n)
self._tasks.append(task)

def arithmetic_op_fields(self, display_name, arithmetic_series):
if self._arithmetic_op_fields_task is None:
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
Expand Down
4 changes: 4 additions & 0 deletions eland/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional, Dict, List, Any

from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
from eland.score import RandomScore


class Query:
Expand Down Expand Up @@ -169,5 +170,8 @@ def update_boolean_filter(self, boolean_filter: BooleanFilter) -> None:
else:
self._query = self._query & boolean_filter

def random_score(self) -> None:
self._query = RandomScore(self._query)

def __repr__(self) -> str:
return repr(self.to_search_body())
26 changes: 24 additions & 2 deletions eland/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
import numpy as np
import pandas as pd

from eland import Index
from eland.field_mappings import FieldMappings
from eland.operations import Operations
from eland.filter import QueryFilter
from eland.operations import Operations
from eland import Index
from eland.common import (
ensure_es_client,
DEFAULT_PROGRESS_REPORTING_NUM_ROWS,
Expand Down Expand Up @@ -403,6 +403,28 @@ def tail(self, n):

return result

def sample(self, n=None, frac=None):
result = self.copy()

if n is None and frac is None:
n = 1
elif n is None and frac is not None:
# fetch index size
stats = self._client.indices.stats(
sethmlarson marked this conversation as resolved.
Show resolved Hide resolved
index=self._index_pattern, metric="indexing"
)
index_length = stats["_all"]["primaries"]["indexing"]["index_total"]
n = int(round(frac * index_length))

if n < 0:
raise ValueError(
"A negative number of rows requested. Please provide positive value."
)

result._operations.sample(n)

return result

def es_query(self, query):
return self._update_query(QueryFilter(query))

Expand Down
31 changes: 31 additions & 0 deletions eland/score.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright 2020 Elasticsearch BV
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class RandomScore:
sethmlarson marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, query):

q = {"match_all": {}}
if not query.empty():
q = query.build()

self._score = {"function_score": {"query": q, "random_score": {}}}

def empty(self):
if self._score is None:
return True
return False

def build(self):
return self._score
28 changes: 27 additions & 1 deletion eland/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from eland.actions import HeadAction, TailAction, SortIndexAction
from eland.arithmetics import ArithmeticSeries


if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401
Expand Down Expand Up @@ -185,6 +184,33 @@ def __repr__(self) -> str:
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"


class SampleTask(SizeTask):
def __init__(self, count):
super().__init__("sample")
self._count = count

def resolve_task(
self,
query_params: QUERY_PARAMS_TYPE,
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params["query"].random_score()

if query_params.get("query_size") is not None:
query_params["query_size"] = min(self._count, query_params["query_size"])
else:
query_params["query_size"] = self._count

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something else to think about here is we want to order by _score (unless pandas maintains the index order after a .sample() call?)

I actually think when I checked this out locally and added the query_params["query_sort_order"] = "score" I found a bug in TailTask not picking up the current query_sort_order when resolving tasks. Something to potentially investigate outside of this issue.

return query_params, post_processing

def size(self) -> int:
return self._count

def __repr__(self) -> str:
return f"('{self._task_type}': ('count': {self._count}))"


class QueryIdsTask(Task):
def __init__(self, must: bool, ids: List[str]):
"""
Expand Down
55 changes: 55 additions & 0 deletions eland/tests/dataframe/test_sample_pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2020 Elasticsearch BV
sethmlarson marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# File called _pytest for PyCharm compatibility

from eland.tests.common import TestData
from eland.tests.common import assert_pandas_eland_frame_equal


class TestDataFrameSample(TestData):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually will need to add a test case that calls .sample() and then other operations such a .head(), .agg(), .shape, etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for your thorough review of this PR, I learn a lot with each comment you made. My doubt with the test is how do I assert is working, I mean what assertion should I check. I already fix some minor issues and believe I can solve the others early next week.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The combination asserts will probably be easier after implementing random_state. Mostly want to verify that we can add additional queries to our .sample() calls without pulling data from ES

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as checking whether .sample() itself is working you could test that calling .sample(10) twice gives you two different sets of rows :)

def test_sample_basic(self):
ed_flights = self.ed_flights()
sample_ed_flights = ed_flights.sample(n=10)._to_pandas()
assert len(sample_ed_flights) == 10

def test_sample_on_boolean_filter(self):
ed_flights = self.ed_flights()
columns = ["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]
shape = ed_flights[columns].sample(n=5)._to_pandas().shape
assert (5, 4) == shape

def test_sample_head(self):
ed_flights = self.ed_flights()
pd_flights = self.pd_flights()

pd_head_5 = pd_flights.head(5)
ed_head_5 = ed_flights.head(5).sample(5)
assert_pandas_eland_frame_equal(pd_head_5, ed_head_5)

def test_sample_frac_values(self):
ed_flights = self.ed_flights()
pd_flights = self.pd_flights()

pd_head_5 = pd_flights.head(5)
ed_head_5 = ed_flights.head(5).sample(frac=1)
assert_pandas_eland_frame_equal(pd_head_5, ed_head_5)

def test_sample_frac_is(self):
frac = 0.1
ed_flights = self.ed_flights()

ed_flights_sample = ed_flights.sample(frac=frac)._to_pandas()
size = len(ed_flights._to_pandas())
assert len(ed_flights_sample) <= int(round(frac * size))