Skip to content

Commit

Permalink
Merge pull request #4 from dagardner-nv/david-chw-dfp-fixes
Browse files Browse the repository at this point in the history
Fix unittests
  • Loading branch information
cwharris committed Jul 14, 2023
2 parents f7aa395 + c98c99d commit f7c4057
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 9 deletions.
8 changes: 8 additions & 0 deletions tests/examples/digital_fingerprinting/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def dask_distributed(fail_missing: bool):
yield import_or_skip("dask.distributed", reason=SKIP_REASON, fail_missing=fail_missing)


@pytest.fixture(autouse=True, scope='session')
def dask_cuda(fail_missing: bool):
    """
    Mark tests requiring dask_cuda
    """
    # Yield (rather than return) so the import gate applies for the whole session;
    # import_or_skip presumably skips the test session when dask_cuda is absent,
    # or fails outright when fail_missing is set — TODO confirm against its definition.
    yield import_or_skip("dask_cuda", reason=SKIP_REASON, fail_missing=fail_missing)


@pytest.fixture(autouse=True, scope='session')
def mlflow(fail_missing: bool):
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/examples/digital_fingerprinting/test_dfp_file_to_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_constructor(config: Config):
@mock.patch('multiprocessing.get_context')
@mock.patch('dask.config')
@mock.patch('dask.distributed.Client')
@mock.patch('dask.distributed.LocalCluster')
@mock.patch('dask_cuda.LocalCUDACluster')
@mock.patch('dfp.stages.dfp_file_to_df._single_object_to_dataframe')
def test_get_or_create_dataframe_from_s3_batch_cache_miss(mock_obf_to_df: mock.MagicMock,
mock_dask_cluster: mock.MagicMock,
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_get_or_create_dataframe_from_s3_batch_cache_miss(mock_obf_to_df: mock.M
@mock.patch('multiprocessing.get_context')
@mock.patch('dask.config')
@mock.patch('dask.distributed.Client')
@mock.patch('dask.distributed.LocalCluster')
@mock.patch('dask_cuda.LocalCUDACluster')
@mock.patch('dfp.stages.dfp_file_to_df._single_object_to_dataframe')
def test_get_or_create_dataframe_from_s3_batch_cache_hit(mock_obf_to_df: mock.MagicMock,
mock_dask_cluster: mock.MagicMock,
Expand Down Expand Up @@ -266,7 +266,7 @@ def test_get_or_create_dataframe_from_s3_batch_cache_hit(mock_obf_to_df: mock.Ma
@mock.patch('multiprocessing.get_context')
@mock.patch('dask.config')
@mock.patch('dask.distributed.Client')
@mock.patch('dask.distributed.LocalCluster')
@mock.patch('dask_cuda.LocalCUDACluster')
@mock.patch('dfp.stages.dfp_file_to_df._single_object_to_dataframe')
def test_get_or_create_dataframe_from_s3_batch_none_noop(mock_obf_to_df: mock.MagicMock,
mock_dask_cluster: mock.MagicMock,
Expand Down
18 changes: 13 additions & 5 deletions tests/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def dask_distributed(fail_missing: bool):
fail_missing=fail_missing)


@pytest.fixture(autouse=True, scope='session')
def dask_cuda(fail_missing: bool):
    """
    Mark tests requiring dask_cuda
    """
    # import_or_skip presumably skips these tests when dask_cuda cannot be imported,
    # or fails when fail_missing is set — TODO confirm against its definition.
    yield import_or_skip("dask_cuda", reason="Downloader requires dask_cuda", fail_missing=fail_missing)


@pytest.mark.usefixtures("restore_environ")
@pytest.mark.parametrize('use_env', [True, False])
@pytest.mark.parametrize('dl_method', ["single_thread", "multiprocess", "multiprocessing", "dask", "dask_thread"])
Expand Down Expand Up @@ -83,7 +91,7 @@ def test_constructor_invalid_dltype(use_env: bool):
@pytest.mark.usefixtures("restore_environ")
@pytest.mark.parametrize('dl_method,use_processes', [("dask", True), ("dask_thread", False)])
@mock.patch('dask.config')
@mock.patch('dask.distributed.LocalCluster')
@mock.patch('dask_cuda.LocalCUDACluster')
def test_get_dask_cluster(mock_dask_cluster: mock.MagicMock,
mock_dask_config: mock.MagicMock,
dl_method: str,
Expand All @@ -93,11 +101,11 @@ def test_get_dask_cluster(mock_dask_cluster: mock.MagicMock,
assert downloader.get_dask_cluster() is mock_dask_cluster

mock_dask_config.set.assert_called_once()
mock_dask_cluster.assert_called_once_with(start=True, processes=use_processes)
mock_dask_cluster.assert_called_once_with()


@mock.patch('dask.config')
@mock.patch('dask.distributed.LocalCluster')
@mock.patch('dask_cuda.LocalCUDACluster')
@pytest.mark.parametrize('dl_method', ["dask", "dask_thread"])
def test_close(mock_dask_cluster: mock.MagicMock, mock_dask_config: mock.MagicMock, dl_method: str):
mock_dask_cluster.return_value = mock_dask_cluster
Expand All @@ -111,7 +119,7 @@ def test_close(mock_dask_cluster: mock.MagicMock, mock_dask_config: mock.MagicMo
mock_dask_cluster.close.assert_called_once()


@mock.patch('dask.distributed.LocalCluster')
@mock.patch('dask_cuda.LocalCUDACluster')
@pytest.mark.parametrize('dl_method', ["single_thread", "multiprocess", "multiprocessing"])
def test_close_noop(mock_dask_cluster: mock.MagicMock, dl_method: str):
mock_dask_cluster.return_value = mock_dask_cluster
Expand All @@ -129,7 +137,7 @@ def test_close_noop(mock_dask_cluster: mock.MagicMock, dl_method: str):
@mock.patch('multiprocessing.get_context')
@mock.patch('dask.config')
@mock.patch('dask.distributed.Client')
@mock.patch('dask.distributed.LocalCluster')
@mock.patch('dask_cuda.LocalCUDACluster')
def test_download(mock_dask_cluster: mock.MagicMock,
mock_dask_client: mock.MagicMock,
mock_dask_config: mock.MagicMock,
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/nvt/test_schema_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_get_ci_column_selector_increment_column():
dtype="datetime64[ns]",
groupby_column="groupby_col")
result = _get_ci_column_selector(col_info)
assert result == "original_name"
assert result == ["groupby_col", "original_name"]


def test_get_ci_column_selector_string_cat_column():
Expand Down

0 comments on commit f7c4057

Please sign in to comment.