Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate config name and data_files in packaged modules #6915

Merged
merged 7 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def _check_legacy_cache2(self, dataset_module: "DatasetModule") -> Optional[str]
and not is_remote_url(self._cache_dir_root)
and not (set(self.config_kwargs) - {"data_files", "data_dir"})
):
from .packaged_modules import _PACKAGED_DATASETS_MODULES
from .packaged_modules import _PACKAGED_DATASETS_MODULES_2_15_HASHES
from .utils._dill import Pickler

def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> str:
Expand All @@ -516,7 +516,7 @@ def update_hash_with_config_parameters(hash: str, config_parameters: dict) -> st
namespace = self.repo_id.split("/")[0] if self.repo_id and self.repo_id.count("/") > 0 else None
with patch.object(Pickler, "_legacy_no_dict_keys_sorting", True):
config_id = self.config.name + "-" + Hasher.hash({"data_files": self.config.data_files})
hash = _PACKAGED_DATASETS_MODULES.get(self.name, "missing")[1]
hash = _PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, "missing")
if (
dataset_module.builder_configs_parameters.metadata_configs
and self.config.name in dataset_module.builder_configs_parameters.metadata_configs
Expand Down
12 changes: 12 additions & 0 deletions src/datasets/packaged_modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def _hash_python_lines(lines: List[str]) -> str:
"webdataset": (webdataset.__name__, _hash_python_lines(inspect.getsource(webdataset).splitlines())),
}

# Hashes of the packaged dataset modules as they existed in `datasets` 2.15,
# keyed by importable module name. Used by the legacy cache lookup in
# `builder.py` (`_PACKAGED_DATASETS_MODULES_2_15_HASHES.get(self.name, ...)`)
# to reproduce pre-existing cache directory names for backward compatibility.
_PACKAGED_DATASETS_MODULES_2_15_HASHES = {
    "csv": "eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d",
    "json": "8bb11242116d547c741b2e8a1f18598ffdd40a1d4f2a2872c7a28b697434bc96",
    "pandas": "3ac4ffc4563c796122ef66899b9485a3f1a977553e2d2a8a318c72b8cc6f2202",
    "parquet": "ca31c69184d9832faed373922c2acccec0b13a0bb5bbbe19371385c3ff26f1d1",
    "arrow": "74f69db2c14c2860059d39860b1f400a03d11bf7fb5a8258ca38c501c878c137",
    "text": "c4a140d10f020282918b5dd1b8a49f0104729c6177f60a6b49ec2a365ec69f34",
    "imagefolder": "7b7ce5247a942be131d49ad4f3de5866083399a0f250901bd8dc202f8c5f7ce5",
    "audiofolder": "d3c1655c66c8f72e4efb5c79e952975fa6e2ce538473a6890241ddbddee9071c",
}

# Used to infer the module to use based on the data files extensions
_EXTENSION_TO_MODULE: Dict[str, Tuple[str, dict]] = {
".csv": ("csv", {}),
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/arrow/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ class ArrowConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

def __post_init__(self):
    # Delegate to BuilderConfig.__post_init__, which validates the config
    # name and data_files (the tests added in this change expect construction
    # of an ArrowConfig to raise InvalidConfigName / ValueError).
    super().__post_init__()


class Arrow(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = ArrowConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/audiofolder/audiofolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class AudioFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

def __post_init__(self):
    # Run the base FolderBasedBuilderConfig/BuilderConfig validation
    # (config name characters, data_files type) on AudioFolderConfig too.
    super().__post_init__()


class AudioFolder(folder_based_builder.FolderBasedBuilder):
BASE_FEATURE = datasets.Audio
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/csv/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class CsvConfig(datasets.BuilderConfig):
date_format: Optional[str] = None

def __post_init__(self):
super().__post_init__()
if self.delimiter is not None:
self.sep = self.delimiter
if self.column_names is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class FolderBasedBuilderConfig(datasets.BuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

def __post_init__(self):
    # Delegate to BuilderConfig.__post_init__ so its name/data_files
    # validation applies to FolderBasedBuilderConfig as well.
    super().__post_init__()


class FolderBasedBuilder(datasets.GeneratorBasedBuilder):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/datasets/packaged_modules/generator/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ class GeneratorConfig(datasets.BuilderConfig):
features: Optional[datasets.Features] = None

def __post_init__(self):
    """Validate the config after dataclass initialization.

    Runs the base ``BuilderConfig`` validation, requires a ``generator``
    to be set, and defaults ``gen_kwargs`` to an empty dict.

    Raises:
        ValueError: If ``generator`` is ``None``.
    """
    super().__post_init__()
    # Raise instead of `assert`: asserts are stripped under `python -O`,
    # and the visible span contained a leftover duplicate assert of this
    # same condition, which is removed here.
    if self.generator is None:
        raise ValueError("generator must be specified")

    if self.gen_kwargs is None:
        self.gen_kwargs = {}
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/imagefolder/imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class ImageFolderConfig(folder_based_builder.FolderBasedBuilderConfig):
drop_labels: bool = None
drop_metadata: bool = None

def __post_init__(self):
    # Run the inherited validation (config name characters, data_files type)
    # for ImageFolderConfig; the tests added in this change rely on it.
    super().__post_init__()


class ImageFolder(folder_based_builder.FolderBasedBuilder):
BASE_FEATURE = datasets.Image
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/json/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class JsonConfig(datasets.BuilderConfig):
chunksize: int = 10 << 20 # 10MB
newlines_in_values: Optional[bool] = None

def __post_init__(self):
    # Delegate to BuilderConfig.__post_init__ so JsonConfig gets the same
    # name/data_files validation as the other packaged-module configs.
    super().__post_init__()


class Json(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = JsonConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/pandas/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ class PandasConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

def __post_init__(self):
    # Explicitly invoke the base-class validation (config name, data_files)
    # for PandasConfig.
    super().__post_init__()


class Pandas(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = PandasConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/parquet/parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class ParquetConfig(datasets.BuilderConfig):
columns: Optional[List[str]] = None
features: Optional[datasets.Features] = None

def __post_init__(self):
    # Explicitly invoke the base-class validation (config name, data_files)
    # for ParquetConfig.
    super().__post_init__()


class Parquet(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = ParquetConfig
Expand Down
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/spark/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class SparkConfig(datasets.BuilderConfig):

features: Optional[datasets.Features] = None

def __post_init__(self):
    # Explicitly invoke the base-class validation (config name, data_files)
    # for SparkConfig.
    super().__post_init__()


def _reorder_dataframe_by_partition(df: "pyspark.sql.DataFrame", new_partition_order: List[int]):
df_combined = df.select("*").where(f"part_id = {new_partition_order[0]}")
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/sql/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class SqlConfig(datasets.BuilderConfig):
features: Optional[datasets.Features] = None

def __post_init__(self):
super().__post_init__()
if self.sql is None:
raise ValueError("sql must be specified")
if self.con is None:
Expand Down
1 change: 1 addition & 0 deletions src/datasets/packaged_modules/text/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class TextConfig(datasets.BuilderConfig):
sample_by: str = "line"

def __post_init__(self, errors):
super().__post_init__()
if errors != "deprecated":
warnings.warn(
"'errors' was deprecated in favor of 'encoding_errors' in version 2.14.0 and will be removed in 3.0.0.\n"
Expand Down
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_arrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.arrow.arrow import ArrowConfig


def test_config_raises_when_invalid_name() -> None:
    """Constructing an ArrowConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        ArrowConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """ArrowConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        ArrowConfig(name="name", data_files=data_files)
16 changes: 14 additions & 2 deletions tests/packaged_modules/test_audiofolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import soundfile as sf

from datasets import Audio, ClassLabel, Features, Value
from datasets.data_files import DataFilesDict, get_data_patterns
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder
from datasets.packaged_modules.audiofolder.audiofolder import AudioFolder, AudioFolderConfig

from ..utils import require_sndfile

Expand Down Expand Up @@ -230,6 +231,17 @@ def data_files_with_zip_archives(tmp_path, audio_file):
return data_files_with_zip_archives


def test_config_raises_when_invalid_name() -> None:
    """Constructing an AudioFolderConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        AudioFolderConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """AudioFolderConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        AudioFolderConfig(name="name", data_files=data_files)


@require_sndfile
# check that labels are inferred correctly from dir names
def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir):
Expand Down
15 changes: 14 additions & 1 deletion tests/packaged_modules/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pytest

from datasets import ClassLabel, Features, Image
from datasets.packaged_modules.csv.csv import Csv
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.csv.csv import Csv, CsvConfig

from ..utils import require_pil

Expand Down Expand Up @@ -86,6 +88,17 @@ def csv_file_with_int_list(tmp_path):
return str(filename)


def test_config_raises_when_invalid_name() -> None:
    """Constructing a CsvConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        CsvConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """CsvConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        CsvConfig(name="name", data_files=data_files)


def test_csv_generate_tables_raises_error_with_malformed_csv(csv_file, malformed_csv_file, caplog):
csv = Csv()
generator = csv._generate_tables([[csv_file, malformed_csv_file]])
Expand Down
14 changes: 13 additions & 1 deletion tests/packaged_modules/test_folder_based_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest

from datasets import ClassLabel, DownloadManager, Features, Value
from datasets.data_files import DataFilesDict, get_data_patterns
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.packaged_modules.folder_based_builder.folder_based_builder import (
FolderBasedBuilder,
Expand Down Expand Up @@ -265,6 +266,17 @@ def data_files_with_zip_archives(tmp_path, auto_text_file):
return data_files_with_zip_archives


def test_config_raises_when_invalid_name() -> None:
    """Constructing a FolderBasedBuilderConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        FolderBasedBuilderConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """FolderBasedBuilderConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        FolderBasedBuilderConfig(name="name", data_files=data_files)


def test_inferring_labels_from_data_dirs(data_files_with_labels_no_metadata, cache_dir):
autofolder = DummyFolderBasedBuilder(
data_files=data_files_with_labels_no_metadata, cache_dir=cache_dir, drop_labels=False
Expand Down
16 changes: 14 additions & 2 deletions tests/packaged_modules/test_imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import pytest

from datasets import ClassLabel, Features, Image, Value
from datasets.data_files import DataFilesDict, get_data_patterns
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesDict, DataFilesList, get_data_patterns
from datasets.download.streaming_download_manager import StreamingDownloadManager
from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder
from datasets.packaged_modules.imagefolder.imagefolder import ImageFolder, ImageFolderConfig

from ..utils import require_pil

Expand Down Expand Up @@ -239,6 +240,17 @@ def data_files_with_zip_archives(tmp_path, image_file):
return data_files_with_zip_archives


def test_config_raises_when_invalid_name() -> None:
    """Constructing an ImageFolderConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        ImageFolderConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """ImageFolderConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        ImageFolderConfig(name="name", data_files=data_files)


@require_pil
# check that labels are inferred correctly from dir names
def test_generate_examples_with_labels(data_files_with_labels_no_metadata, cache_dir):
Expand Down
15 changes: 14 additions & 1 deletion tests/packaged_modules/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import pytest

from datasets import Features, Value
from datasets.packaged_modules.json.json import Json
from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.json.json import Json, JsonConfig


@pytest.fixture
Expand Down Expand Up @@ -171,6 +173,17 @@ def json_file_with_list_of_dicts_with_sorted_columns_field(tmp_path):
return str(path)


def test_config_raises_when_invalid_name() -> None:
    """Constructing a JsonConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        JsonConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """JsonConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        JsonConfig(name="name", data_files=data_files)


@pytest.mark.parametrize(
"file_fixture, config_kwargs",
[
Expand Down
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.pandas.pandas import PandasConfig


def test_config_raises_when_invalid_name() -> None:
    """Constructing a PandasConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        PandasConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """PandasConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        PandasConfig(name="name", data_files=data_files)
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_parquet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.parquet.parquet import ParquetConfig


def test_config_raises_when_invalid_name() -> None:
    """Constructing a ParquetConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        ParquetConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """ParquetConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        ParquetConfig(name="name", data_files=data_files)
15 changes: 15 additions & 0 deletions tests/packaged_modules/test_spark.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from unittest.mock import patch

import pyspark
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.spark.spark import (
Spark,
SparkConfig,
SparkExamplesIterable,
_generate_iterable_examples,
)
Expand All @@ -23,6 +27,17 @@ def _get_expected_row_ids_and_row_dicts_for_partition_order(df, partition_order)
return expected_row_ids_and_row_dicts


def test_config_raises_when_invalid_name() -> None:
    """Constructing a SparkConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        SparkConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """SparkConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        SparkConfig(name="name", data_files=data_files)


@require_not_windows
@require_dill_gt_0_3_2
def test_repartition_df_if_needed():
Expand Down
16 changes: 16 additions & 0 deletions tests/packaged_modules/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import pytest

from datasets.builder import InvalidConfigName
from datasets.data_files import DataFilesList
from datasets.packaged_modules.sql.sql import SqlConfig


def test_config_raises_when_invalid_name() -> None:
    """Constructing a SqlConfig with forbidden name characters fails."""
    with pytest.raises(InvalidConfigName, match="Bad characters"):
        SqlConfig(name="name-with-*-invalid-character")


@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])])
def test_config_raises_when_invalid_data_files(data_files) -> None:
    """SqlConfig rejects data_files that are not a DataFilesDict."""
    with pytest.raises(ValueError, match="Expected a DataFilesDict"):
        SqlConfig(name="name", data_files=data_files)
Loading
Loading