diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 7184fe25255..5f3d0edec2f 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -3101,6 +3101,65 @@ def test_dataset_from_text_path_type(path_type, text_path, tmp_path):
     _check_text_dataset(dataset, expected_features)
 
 
+@pytest.fixture
+def data_generator():
+    """Provide a generator function that yields four rows with columns col_1, col_2 and col_3."""
+
+    def _gen():
+        data = [
+            {"col_1": "0", "col_2": 0, "col_3": 0.0},
+            {"col_1": "1", "col_2": 1, "col_3": 1.0},
+            {"col_1": "2", "col_2": 2, "col_3": 2.0},
+            {"col_1": "3", "col_2": 3, "col_3": 3.0},
+        ]
+        yield from data
+
+    return _gen
+
+
+def _check_generator_dataset(dataset, expected_features):
+    """Check the type, shape, column names and feature dtypes of a generator-built dataset."""
+    assert isinstance(dataset, Dataset)
+    assert dataset.num_rows == 4
+    assert dataset.num_columns == 3
+    assert dataset.column_names == ["col_1", "col_2", "col_3"]
+    for feature, expected_dtype in expected_features.items():
+        assert dataset.features[feature].dtype == expected_dtype
+
+
+@pytest.mark.parametrize("keep_in_memory", [False, True])
+def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path):
+    """Exercise ``Dataset.from_generator`` with and without ``keep_in_memory``."""
+    cache_dir = tmp_path / "cache"
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
+        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
+    _check_generator_dataset(dataset, expected_features)
+
+
+@pytest.mark.parametrize(
+    "features",
+    [
+        None,
+        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
+        {"col_1": "string", "col_2": "string", "col_3": "string"},
+        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
+        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
+    ],
+)
+def test_dataset_from_generator_features(features, data_generator, tmp_path):
+    """Exercise ``Dataset.from_generator`` with no features (``None``) and with explicitly requested dtypes."""
+    cache_dir = tmp_path / "cache"
+    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    # Use the same ``is not None`` test for both branches so they cannot disagree on an empty mapping.
+    expected_features = features.copy() if features is not None else default_expected_features
+    features = (
+        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
+    )
+    dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
+    _check_generator_dataset(dataset, expected_features)
+
+
 def test_dataset_to_json(dataset, tmp_path):
     file_path = tmp_path / "test_path.jsonl"
     bytes_written = dataset.to_json(path_or_buf=file_path)