Skip to content

Commit

Permalink
FIX-#3510: Correct processing of callable skiprows parameter of `re…
Browse files Browse the repository at this point in the history
…ad_csv` function (#3511)

Co-authored-by: Vasily Litvinov <vasilij.n.litvinov@intel.com>
Signed-off-by: Alexander Myskov <alexander.myskov@intel.com>
  • Loading branch information
amyskov and vnlitvinov committed Dec 2, 2021
1 parent 7a7ae18 commit 7c84758
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 4 deletions.
38 changes: 34 additions & 4 deletions modin/core/io/text/text_file_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,10 +889,10 @@ def _get_new_qc(
index=index_range.delete(skiprows_md)
)
elif callable(skiprows_md):
mod_index = skiprows_md(index_range)
if not isinstance(mod_index, np.ndarray):
mod_index = mod_index.to_numpy("bool")
view_idx = index_range[~mod_index]
skip_mask = cls._get_skip_mask(index_range, skiprows_md)
if not isinstance(skip_mask, np.ndarray):
skip_mask = skip_mask.to_numpy("bool")
view_idx = index_range[~skip_mask]
new_query_compiler = new_query_compiler.view(index=view_idx)
else:
raise TypeError(
Expand Down Expand Up @@ -1037,3 +1037,33 @@ def _read(cls, filepath_or_buffer: FilePathOrBuffer, **kwargs):
nrows=kwargs["nrows"] if should_handle_skiprows else None,
)
return new_query_compiler

@classmethod
def _get_skip_mask(cls, rows_index: pandas.Index, skiprows: Callable):
"""
Get mask of skipped by callable `skiprows` rows.
Parameters
----------
rows_index : pandas.Index
Rows index to get mask for.
skiprows : Callable
Callable to check whether row index should be skipped.
Returns
-------
pandas.Index
"""
try:
# direct `skiprows` call is more efficient than using of
# map method, but in some cases it can work incorrectly, e.g.
# when `skiprows` contains `in` operator
mask = skiprows(rows_index)
assert is_list_like(mask)
except (ValueError, TypeError, AssertionError):
# ValueError can be raised if `skiprows` callable contains membership operator
# TypeError is raised if `skiprows` callable contains bitwise operator
# AssertionError is raised if unexpected behavior was detected
mask = rows_index.map(skiprows)

return mask
24 changes: 24 additions & 0 deletions modin/pandas/test/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,30 @@ def test_read_csv_empty_frame(self):
index_col="col1",
)

@pytest.mark.parametrize(
"skiprows",
[
lambda x: x > 20,
lambda x: True,
lambda x: x in [10, 20],
pytest.param(
lambda x: x << 10,
marks=pytest.mark.skipif(
condition="config.getoption('--simulate-cloud').lower() != 'off'",
reason="The reason of tests fail in `cloud` mode is unknown for now - issue #2340",
),
),
],
)
def test_read_csv_skiprows_corner_cases(self, skiprows):
eval_io(
fn_name="read_csv",
check_kwargs_callable=not callable(skiprows),
# read_csv kwargs
filepath_or_buffer=pytest.csvs_names["test_read_csv_regular"],
skiprows=skiprows,
)


class TestTable:
def test_read_table(self, make_csv_file):
Expand Down

0 comments on commit 7c84758

Please sign in to comment.