Skip to content

Commit

Permalink
Validate config name and data_files in packaged modules (#6915)
Browse files Browse the repository at this point in the history
* Make configs call super post_init in packaged modules

* Update hash in test

* Add tests

* Add tests for BuilderConfig

* Fix syntax

* use old hash for 2.15 cache reload

---------

Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
  • Loading branch information
albertvillanova and lhoestq committed Jun 6, 2024
1 parent 6548e0e commit 5bbbf1b
Show file tree
Hide file tree
Showing 26 changed files with 226 additions and 12 deletions.
4 changes: 2 additions & 2 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]
and not is_remote_url(self._cache_dir_root)
and not (set(self.config_kwargs) - {"data_files", "data_dir"})
):
from .packaged_modules import _PACKAGED_DATASETS_MODULES
from .packaged_modules import _PACKAGED_DATASETS_MODULES_2_15_HASHES
from .utils._dill import Pickler

def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
Expand All @@ -516,7 +516,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st
namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
hash = _PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, "missing")
if (
dataset_module.builder_configs_parameters.metadata_configs
and self.config.name in dataset_module.builder_configs_parameters.metadata_configs
Expand Down
12 changes: 12 additions & 0 deletions src/datasets/packaged_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def _hash_python_lines(lines: List[str]) -> str:
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
}

# get importable module names and hash for caching
_PACKAGED_DATASETS_MODULES_2_15_HASHES = {
"csv": "eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d",
"json": "8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96",
"pandas": "3ac4ffc4563c796122ef66899b9485a3f1a977553e2d2a8a318c72b8cc6f2202",
"parquet": "ca31c69184d9832faed373922c2acccec0b13a0bb5bbbe19371385c3ff26f1d1",
"arrow": "74f69db2c14c2860059d39860b1f400a03d11bf7fb5a8258ca38c501c878c137",
"text": "c4a140d10f020282918b5dd1b8a49f0104729c6177f60a6b49ec2a365ec69f34",
"imagefolder": "7b7ce5247a942be131d49ad4f3de5866083399a0f250901bd8dc202f8c5f7ce5",
"audiofolder": "d3c1655c66c8f72e4efb5c79e952975fa6e2ce538473a6890241ddbddee9071c",
}

# Used to infer the module to use based on the data files extensions
_EXTENSION_TO_MODULE: Dict[str, Tuple[str, dict]] = {
".csv": ("csv", {}),
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/arrow/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class ArrowConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

def __post_init__(self):
super().__post_init__()


class Arrow(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = ArrowConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/audiofolder/audiofolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class AudioFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

def __post_init__(self):
super().__post_init__()


class AudioFolder(folder_based_builder.FolderBasedBuilder):
BASE_FEATURE = datasets.Audio
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/csv/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class CsvConfig(datasets.BuilderConfig):
date_format: Optional[str] = None

def __post_init__(self):
super().__post_init__()
if self.delimiter is not None:
self.sep = self.delimiter
if self.column_names is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class FolderBasedBuilderConfig(datasets.BuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

def __post_init__(self):
super().__post_init__()


class FolderBasedBuilder(datasets.GeneratorBasedBuilder):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/datasets/packaged_modules/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class GeneratorConfig(datasets.BuilderConfig):
features: Optional[datasets.Features] = None

def __post_init__(self):
assert self.generator is not None, "generator must be specified"
super().__post_init__()
if self.generator is None:
raise ValueError("generator must be specified")

if self.gen_kwargs is None:
self.gen_kwargs = {}
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/imagefolder/imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class ImageFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

def __post_init__(self):
super().__post_init__()


class ImageFolder(folder_based_builder.FolderBasedBuilder):
BASE_FEATURE = datasets.Image
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class JsonConfig(datasets.BuilderConfig):
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None

def __post_init__(self):
super().__post_init__()


class Json(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = JsonConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/pandas/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class PandasConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

def __post_init__(self):
super().__post_init__()


class Pandas(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = PandasConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/parquet/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class ParquetConfig(datasets.BuilderConfig):
columns: Optional[List[str]] = None
features: Optional[datasets.Features] = None

def __post_init__(self):
super().__post_init__()


class Parquet(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = ParquetConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/spark/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class SparkConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

def __post_init__(self):
super().__post_init__()


def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: List[int]):
df_combined = df.select("*").where(f"part_id = {new_partition_order[0]}")
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SqlConfig(datasets.BuilderConfig):
features: Optional[datasets.Features] = None

def __post_init__(self):
super().__post_init__()
if self.sql is None:
raise ValueError("sql must be specified")
if self.con is None:
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/text/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TextConfig(datasets.BuilderConfig):
sample_by: str = "line"

def __post_init__(self, errors):
super().__post_init__()
if errors != "deprecated":
warnings.warn(
"'errors' was deprecated in favor of 'encoding_errors' in version 2.14.0 and will be removed in 3.0.0.\n"
Expand Down
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.arrow.arrow import ArrowConfig


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = ArrowConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = ArrowConfig(name="name", data_files=data_files)
16 changes: 14 additions & 2 deletions tests/packaged_modules/test_audiofolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import soundfile as sf

from datasets import Audio, ClassLabel, Features, Value
from datasets.data_files import DataFilesDict, get_data_patterns
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder
from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder, AudioFolderConfig

from ..utils import require_sndfile

Expand Down Expand Up @@ -230,6 +231,17 @@ def data_files_with_zip_archives(tmp_path, audio_file):
return data_files_with_zip_archives


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = AudioFolderConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = AudioFolderConfig(name="name", data_files=data_files)


@require_sndfile
# check that labels are inferred correctly from dir names
def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir):
Expand Down
15 changes: 14 additions & 1 deletion tests/packaged_modules/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pytest

from datasets import ClassLabel, Features, Image
from datasets.packaged_modules.csv.csv import Csv
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.csv.csv import Csv, CsvConfig

from ..utils import require_pil

Expand Down Expand Up @@ -86,6 +88,17 @@ def csv_file_with_int_list(tmp_path):
return str(filename)


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = CsvConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = CsvConfig(name="name", data_files=data_files)


def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog):
csv = Csv()
generator = csv._generate_tables([[csv_file, malformed_csv_file]])
Expand Down
14 changes: 13 additions & 1 deletion tests/packaged_modules/test_folder_based_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

from datasets import ClassLabel, DownloadManager, Features, Value
from datasets.data_files import DataFilesDict, get_data_patterns
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
FolderBasedBuilder,
Expand Down Expand Up @@ -265,6 +266,17 @@ def data_files_with_zip_archives(tmp_path, auto_text_file):
return data_files_with_zip_archives


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = FolderBasedBuilderConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = FolderBasedBuilderConfig(name="name", data_files=data_files)


def test_inferring_labels_from_data_dirs(data_files_with_labels_no_metadata, cache_dir):
autofolder = DummyFolderBasedBuilder(
data_files=data_files_with_labels_no_metadata, cache_dir=cache_dir, drop_labels=False
Expand Down
16 changes: 14 additions & 2 deletions tests/packaged_modules/test_imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import pytest

from datasets import ClassLabel, Features, Image, Value
from datasets.data_files import DataFilesDict, get_data_patterns
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder
from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder, ImageFolderConfig

from ..utils import require_pil

Expand Down Expand Up @@ -239,6 +240,17 @@ def data_files_with_zip_archives(tmp_path, image_file):
return data_files_with_zip_archives


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = ImageFolderConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = ImageFolderConfig(name="name", data_files=data_files)


@require_pil
# check that labels are inferred correctly from dir names
def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir):
Expand Down
15 changes: 14 additions & 1 deletion tests/packaged_modules/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import pytest

from datasets import Features, Value
from datasets.packaged_modules.json.json import Json
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.json.json import Json, JsonConfig


@pytest.fixture
Expand Down Expand Up @@ -171,6 +173,17 @@ def json_file_with_list_of_dicts_with_sorted_columns_field(tmp_path):
return str(path)


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = JsonConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = JsonConfig(name="name", data_files=data_files)


@pytest.mark.parametrize(
"file_fixture, config_kwargs",
[
Expand Down
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.pandas.pandas import PandasConfig


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = PandasConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = PandasConfig(name="name", data_files=data_files)
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_parquet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.parquet.parquet import ParquetConfig


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = ParquetConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = ParquetConfig(name="name", data_files=data_files)
15 changes: 15 additions & 0 deletions tests/packaged_modules/test_spark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from unittest.mock import patch

import pyspark
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.spark.spark import (
Spark,
SparkConfig,
SparkExamplesIterable,
_generate_iterable_examples,
)
Expand All @@ -23,6 +27,17 @@ def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order)
return expected_row_ids_and_row_dicts


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = SparkConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = SparkConfig(name="name", data_files=data_files)


@require_not_windows
@require_dill_gt_0_3_2
def test_repartition_df_if_needed():
Expand Down
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.sql.sql import SqlConfig


def test_config_raises_when_invalid_name() -> None:
with pytest.raises(InvalidConfigName, match="Bad characters"):
_ = SqlConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
with pytest.raises(ValueError, match="Expected a DataFilesDict"):
_ = SqlConfig(name="name", data_files=data_files)
Loading

0 comments on commit 5bbbf1b

Please sign in to comment.