Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sample method to eland.DataFrame #196

Merged
merged 8 commits into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/source/reference/api/eland.DataFrame.sample.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
eland.DataFrame.sample
======================

.. currentmodule:: eland

.. automethod:: DataFrame.sample
6 changes: 6 additions & 0 deletions docs/source/reference/api/eland.Series.sample.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
eland.Series.sample
===================

.. currentmodule:: eland

.. automethod:: Series.sample
9 changes: 5 additions & 4 deletions docs/source/reference/dataframe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ Attributes and underlying data

DataFrame.index
DataFrame.columns
DataFrame.dtypes
DataFrame.select_dtypes
DataFrame.values
DataFrame.empty
DataFrame.dtypes
DataFrame.select_dtypes
DataFrame.values
DataFrame.empty
DataFrame.shape

Indexing, iteration
Expand All @@ -37,6 +37,7 @@ Indexing, iteration
DataFrame.tail
DataFrame.get
DataFrame.query
DataFrame.sample

Function application, GroupBy & window
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
5 changes: 3 additions & 2 deletions docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ Attributes and underlying data

Series.index
Series.shape
Series.name
Series.empty
Series.name
Series.empty

Indexing, iteration
~~~~~~~~~~~~~~~~~~~
Expand All @@ -31,6 +31,7 @@ Indexing, iteration

Series.head
Series.tail
Series.sample

Binary operator functions
~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
39 changes: 39 additions & 0 deletions eland/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,45 @@ def tail(self, n: int = 5) -> "DataFrame":
"""
return DataFrame(query_compiler=self._query_compiler.tail(n))

def sample(
self, n: int = None, frac: float = None, random_state: int = None
) -> "DataFrame":
"""
Return n randomly sample rows or the specify fraction of rows

Parameters
----------
n : int, optional
Number of documents from index to return. Cannot be used with `frac`.
Default = 1 if `frac` = None.
frac : float, optional
Fraction of axis items to return. Cannot be used with `n`.
random_state : int, optional
Seed for the random number generator.

Returns
-------
eland.DataFrame:
eland DataFrame filtered containing n rows randomly sampled

See Also
--------
:pandas_api_docs:`pandas.DataFrame.sample`
"""

if frac is not None and not (0.0 < frac <= 1.0):
raise ValueError("`frac` must be between 0. and 1.")
elif n is not None and frac is None and n % 1 != 0:
raise ValueError("Only integers accepted as `n` values")
elif (n is not None) == (frac is not None):
raise ValueError("Please enter a value for `frac` OR `n`, not both")

return DataFrame(
query_compiler=self._query_compiler.sample(
n=n, frac=frac, random_state=random_state
)
)

def drop(
self,
labels=None,
Expand Down
16 changes: 16 additions & 0 deletions eland/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,19 @@ class QueryFilter(BooleanFilter):
def __init__(self, query: Dict[str, Any]) -> None:
super().__init__()
self._filter = query


class MatchAllFilter(QueryFilter):
def __init__(self) -> None:
super().__init__({"match_all": {}})


class RandomScoreFilter(QueryFilter):
def __init__(self, query: QueryFilter, random_state: int) -> None:
q = MatchAllFilter() if query.empty() else query

seed = {}
if random_state is not None:
seed = {"seed": random_state, "field": "_seq_no"}

super().__init__({"function_score": {"query": q.build(), "random_score": seed}})
4 changes: 4 additions & 0 deletions eland/ndframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,7 @@ def head(self, n=5):
@abstractmethod
def tail(self, n=5):
pass

@abstractmethod
sethmlarson marked this conversation as resolved.
Show resolved Hide resolved
def sample(self, n=None, frac=None, random_state=None):
pass
5 changes: 5 additions & 0 deletions eland/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from eland.tasks import (
HeadTask,
TailTask,
SampleTask,
BooleanFilterTask,
ArithmeticOpFieldsTask,
QueryTermsTask,
Expand Down Expand Up @@ -84,6 +85,10 @@ def tail(self, index, n):
task = TailTask(index.sort_field, n)
self._tasks.append(task)

def sample(self, index, n, random_state):
task = SampleTask(index.sort_field, n, random_state)
self._tasks.append(task)

def arithmetic_op_fields(self, display_name, arithmetic_series):
if self._arithmetic_op_fields_task is None:
self._arithmetic_op_fields_task = ArithmeticOpFieldsTask(
Expand Down
5 changes: 4 additions & 1 deletion eland/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from copy import deepcopy
from typing import Optional, Dict, List, Any

from eland.filter import BooleanFilter, NotNull, IsNull, IsIn
from eland.filter import RandomScoreFilter, BooleanFilter, NotNull, IsNull, IsIn


class Query:
Expand Down Expand Up @@ -152,5 +152,8 @@ def update_boolean_filter(self, boolean_filter: BooleanFilter) -> None:
else:
self._query = self._query & boolean_filter

def random_score(self, random_state: int) -> None:
self._query = RandomScoreFilter(self._query, random_state)

def __repr__(self) -> str:
return repr(self.to_search_body())
22 changes: 20 additions & 2 deletions eland/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
import numpy as np
import pandas as pd

from eland import Index
from eland.field_mappings import FieldMappings
from eland.operations import Operations
from eland.filter import QueryFilter
from eland.operations import Operations
from eland import Index
from eland.common import (
ensure_es_client,
DEFAULT_PROGRESS_REPORTING_NUM_ROWS,
Expand Down Expand Up @@ -393,6 +393,24 @@ def tail(self, n):

return result

def sample(self, n=None, frac=None, random_state=None):
result = self.copy()

if n is None and frac is None:
n = 1
elif n is None and frac is not None:
index_length = self._index_count()
n = int(round(frac * index_length))

if n < 0:
raise ValueError(
"A negative number of rows requested. Please provide positive value."
)

result._operations.sample(self._index, n, random_state)

return result

def es_query(self, query):
return self._update_query(QueryFilter(query))

Expand Down
3 changes: 3 additions & 0 deletions eland/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,9 @@ def head(self, n=5):
def tail(self, n=5):
return Series(query_compiler=self._query_compiler.tail(n))

def sample(self, n: int = None, frac: float = None, random_state: int = None):
return Series(query_compiler=self._query_compiler.sample(n, frac, random_state))

def value_counts(self, es_size=10):
"""
Return the value counts for the specified field.
Expand Down
38 changes: 37 additions & 1 deletion eland/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from eland.actions import HeadAction, TailAction, SortIndexAction
from eland.arithmetics import ArithmeticSeries


if TYPE_CHECKING:
from .actions import PostProcessingAction # noqa: F401
from .filter import BooleanFilter # noqa: F401
Expand Down Expand Up @@ -175,6 +174,43 @@ def __repr__(self) -> str:
return f"('{self._task_type}': ('sort_field': '{self._sort_field}', 'count': {self._count}))"


class SampleTask(SizeTask):
def __init__(self, sort_field: str, count: int, random_state: int):
super().__init__("sample")
self._count = count
self._random_state = random_state
self._sort_field = sort_field

def resolve_task(
self,
query_params: "QueryParams",
post_processing: List["PostProcessingAction"],
query_compiler: "QueryCompiler",
) -> RESOLVED_TASK_TYPE:
query_params.query.random_score(self._random_state)

query_sort_field = self._sort_field
query_size = self._count

if query_params.size is not None:
query_params.size = min(query_size, query_params.size)
else:
query_params.size = query_size

if query_params.sort_field is None:
query_params.sort_field = query_sort_field

post_processing.append(SortIndexAction())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something else to think about here is we want to order by _score (unless pandas maintains the index order after a .sample() call?)

I actually think when I checked this out locally and added the query_params["query_sort_order"] = "score" I found a bug in TailTask not picking up the current query_sort_order when resolving tasks. Something to potentially investigate outside of this issue.

return query_params, post_processing

def size(self) -> int:
return self._count

def __repr__(self) -> str:
return f"('{self._task_type}': ('count': {self._count}))"


class QueryIdsTask(Task):
def __init__(self, must: bool, ids: List[str]):
"""
Expand Down
88 changes: 88 additions & 0 deletions eland/tests/dataframe/test_sample_pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information

# File called _pytest for PyCharm compatibility
import pytest
from pandas.testing import assert_frame_equal

from eland.tests.common import TestData
from eland.utils import eland_to_pandas


class TestDataFrameSample(TestData):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Eventually will need to add a test case that calls .sample() and then other operations such a .head(), .agg(), .shape, etc

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you so much for your thorough review of this PR, I learn a lot with each comment you made. My doubt with the test is how do I assert is working, I mean what assertion should I check. I already fix some minor issues and believe I can solve the others early next week.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The combination asserts will probably be easier after implementing random_state. Mostly want to verify that we can add additional queries to our .sample() calls without pulling data from ES

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as checking whether .sample() itself is working you could test that calling .sample(10) twice gives you two different sets of rows :)

SEED = 42

def build_from_index(self, sample_ed_flights):
sample_pd_flights = self.pd_flights_small().loc[
sample_ed_flights.index, sample_ed_flights.columns
]
return sample_pd_flights

def test_sample(self):
ed_flights_small = self.ed_flights_small()
first_sample = ed_flights_small.sample(n=10, random_state=self.SEED)
second_sample = ed_flights_small.sample(n=10, random_state=self.SEED)

assert_frame_equal(
eland_to_pandas(first_sample), eland_to_pandas(second_sample)
)

def test_sample_raises(self):
ed_flights_small = self.ed_flights_small()

with pytest.raises(ValueError):
ed_flights_small.sample(n=10, frac=0.1)

with pytest.raises(ValueError):
ed_flights_small.sample(frac=1.5)

with pytest.raises(ValueError):
ed_flights_small.sample(n=-1)

def test_sample_basic(self):
ed_flights_small = self.ed_flights_small()
sample_ed_flights = ed_flights_small.sample(n=10, random_state=self.SEED)
pd_from_eland = eland_to_pandas(sample_ed_flights)

# build using index
sample_pd_flights = self.build_from_index(pd_from_eland)

assert_frame_equal(sample_pd_flights, pd_from_eland)

def test_sample_frac_01(self):
frac = 0.15
ed_flights = self.ed_flights_small().sample(frac=frac, random_state=self.SEED)
pd_from_eland = eland_to_pandas(ed_flights)
pd_flights = self.build_from_index(pd_from_eland)

assert_frame_equal(pd_flights, pd_from_eland)

# assert right size from pd_flights
size = len(self.pd_flights_small())
assert len(pd_flights) == int(round(frac * size))

def test_sample_on_boolean_filter(self):
ed_flights = self.ed_flights_small()
columns = ["timestamp", "OriginAirportID", "DestAirportID", "FlightDelayMin"]
sample_ed_flights = ed_flights[columns].sample(n=5, random_state=self.SEED)
pd_from_eland = eland_to_pandas(sample_ed_flights)
sample_pd_flights = self.build_from_index(pd_from_eland)

assert_frame_equal(sample_pd_flights, pd_from_eland)

def test_sample_head(self):
ed_flights = self.ed_flights_small()
sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED)
sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))

pd_head_5 = sample_pd_flights.head(5)
ed_head_5 = sample_ed_flights.head(5)
assert_frame_equal(pd_head_5, eland_to_pandas(ed_head_5))

def test_sample_shape(self):
ed_flights = self.ed_flights_small()
sample_ed_flights = ed_flights.sample(n=10, random_state=self.SEED)
sample_pd_flights = self.build_from_index(eland_to_pandas(sample_ed_flights))

assert sample_pd_flights.shape == sample_ed_flights.shape
25 changes: 25 additions & 0 deletions eland/tests/series/test_sample_pytest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Licensed to Elasticsearch B.V under one or more agreements.
# Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
# See the LICENSE file in the project root for more information

# File called _pytest for PyCharm compatibility
import eland as ed
from eland.tests import ES_TEST_CLIENT
from eland.tests import FLIGHTS_INDEX_NAME
from eland.tests.common import TestData
from eland.tests.common import assert_pandas_eland_series_equal


class TestSeriesSample(TestData):
SEED = 42

def build_from_index(self, ed_series):
ed2pd_series = ed_series._to_pandas()
return self.pd_flights()["Carrier"].iloc[ed2pd_series.index]

def test_sample(self):
ed_s = ed.Series(ES_TEST_CLIENT, FLIGHTS_INDEX_NAME, "Carrier")
pd_s = self.build_from_index(ed_s.sample(n=10, random_state=self.SEED))

ed_s_sample = ed_s.sample(n=10, random_state=self.SEED)
assert_pandas_eland_series_equal(pd_s, ed_s_sample)