Add features parameter to CSV #685

Merged
merged 4 commits on Sep 30, 2020
7 changes: 5 additions & 2 deletions datasets/csv/csv.py
@@ -20,6 +20,7 @@ class CsvConfig(datasets.BuilderConfig):
read_options: pac.ReadOptions = None
parse_options: pac.ParseOptions = None
convert_options: pac.ConvertOptions = None
features: datasets.Features = None

@property
def pa_read_options(self):
@@ -43,7 +44,9 @@ def pa_parse_options(self):

@property
def pa_convert_options(self):
convert_options = self.convert_options or pac.ConvertOptions()
convert_options = self.convert_options or pac.ConvertOptions(
column_types=self.features.type if self.features is not None else None
)
return convert_options


@@ -78,6 +81,6 @@ def _generate_tables(self, files):
file,
read_options=self.config.pa_read_options,
parse_options=self.config.pa_parse_options,
convert_options=self.config.convert_options,
convert_options=self.config.pa_convert_options,
Member (inline review comment):
Good catch!

)
yield i, pa_table
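
The substance of the csv.py change: `pa_convert_options` now forwards the declared `features` to PyArrow as explicit column types (via `self.features.type`), and `_generate_tables` is fixed to call the `pa_convert_options` property rather than the raw `convert_options` attribute, so those options actually reach the reader. Below is a minimal standalone sketch of the underlying PyArrow call, with a placeholder file name and column mapping rather than anything taken from this PR:

import pyarrow as pa
import pyarrow.csv as pac

# column_types disables type inference for the listed columns, so the Arrow
# table matches the declared schema instead of per-file guesses.
# The dict below is a placeholder; in the builder the mapping is derived
# from datasets.Features via self.features.type.
convert_options = pac.ConvertOptions(
    column_types={"foo": pa.string(), "bar": pa.string()}
)
table = pac.read_csv("table.csv", convert_options=convert_options)  # placeholder path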
2 changes: 1 addition & 1 deletion datasets/text/text.py
@@ -104,7 +104,7 @@ def _generate_tables(self, files):
file,
read_options=self.config.pa_read_options,
parse_options=self.config.pa_parse_options,
convert_options=self.config.convert_options,
convert_options=self.config.pa_convert_options,
)
# Uncomment for debugging (will print the Arrow table size and elements)
# logger.warning(f"pa_table: {pa_table} num rows: {pa_table.num_rows}")
14 changes: 9 additions & 5 deletions src/datasets/builder.py
@@ -25,14 +25,14 @@
from functools import partial
from typing import Dict, List, Optional, Union

import xxhash
from filelock import FileLock

from . import utils
from .arrow_dataset import Dataset
from .arrow_reader import HF_GCP_BASE_URL, ArrowReader, DatasetNotOnHfGcs, MissingFilesOnHfGcs
from .arrow_writer import ArrowWriter, BeamWriter
from .dataset_dict import DatasetDict
from .fingerprint import Hasher
from .info import (
DATASET_INFO_FILENAME,
DATASET_INFOS_DICT_FILE_NAME,
@@ -153,6 +153,8 @@ def __init__(

# Prepare config: DatasetConfig contains name, version and description but can be extended by each dataset
config_kwargs = dict((key, value) for key, value in config_kwargs.items() if value is not None)
if "features" in inspect.signature(self.BUILDER_CONFIG_CLASS.__init__).parameters and features is not None:
config_kwargs["features"] = features
self.config = self._create_builder_config(
name,
**config_kwargs,
@@ -256,7 +258,7 @@ def _create_builder_config(self, name=None, **config_kwargs):
# if not builder_config.description:
# raise ValueError("BuilderConfig %s must have a description" % name)
if builder_config.data_files is not None:
m = xxhash.xxh64()
m = Hasher()
if isinstance(builder_config.data_files, str):
data_files = {"train": [builder_config.data_files]}
elif isinstance(builder_config.data_files, (tuple, list)):
@@ -269,10 +271,12 @@ def _create_builder_config(self, name=None, **config_kwargs):
else:
raise ValueError("Please provide a valid `data_files` in `DatasetBuilder`")
for key in sorted(data_files.keys()):
m.update(key.encode("utf-8"))
m.update(key)
for data_file in data_files[key]:
m.update(os.path.abspath(data_file).encode("utf-8"))
m.update(str(os.path.getmtime(data_file)).encode("utf-8"))
m.update(os.path.abspath(data_file))
m.update(str(os.path.getmtime(data_file)))
if hasattr(builder_config, "features"):
m.update(builder_config.features)
builder_config.name += "-" + m.hexdigest()
return builder_config
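
In builder.py the ad-hoc xxhash digest is replaced with `Hasher` from `datasets.fingerprint`, which hashes arbitrary Python objects (no manual `.encode()` needed), and the declared `features` are folded into the config name, so requesting a different schema yields a different cache entry. A rough sketch of the suffix computation above, using a hypothetical path and modification time:

from datasets import Features, Value
from datasets.fingerprint import Hasher

# Rough sketch of how the config-name suffix is derived; the path and
# modification time below are made-up values for illustration.
m = Hasher()
m.update("train")                               # sorted data_files key
m.update("/abs/path/table.csv")                 # absolute data file path
m.update(str(1601424000.0))                     # file mtime, as a string
m.update(Features({"foo": Value("string")}))    # features now affect the hash
config_suffix = m.hexdigest()                   # appended as name + "-" + suffix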

81 changes: 79 additions & 2 deletions tests/test_dataset_common.py
@@ -28,8 +28,10 @@
BuilderConfig,
DatasetBuilder,
DownloadConfig,
Features,
GenerateMode,
MockDownloadManager,
Value,
cached_path,
hf_api,
hf_bucket_url,
@@ -375,13 +377,17 @@ def test_load_real_dataset_all_configs(self, dataset_name):

class TextTest(TestCase):
def test_caching(self):
n_samples = 10
with tempfile.TemporaryDirectory() as tmp_dir:
open(os.path.join(tmp_dir, "text.txt"), "w", encoding="utf-8").write("\n".join("foo" for _ in range(10)))
open(os.path.join(tmp_dir, "text.txt"), "w", encoding="utf-8").write(
"\n".join("foo" for _ in range(n_samples))
)
ds = load_dataset(
"./datasets/text", data_files=os.path.join(tmp_dir, "text.txt"), cache_dir=tmp_dir, split="train"
)
data_file = ds._data_files[0]
fingerprint = ds._fingerprint
self.assertEqual(len(ds), n_samples)
del ds
ds = load_dataset(
"./datasets/text", data_files=os.path.join(tmp_dir, "text.txt"), cache_dir=tmp_dir, split="train"
@@ -390,10 +396,81 @@ def test_caching(self):
self.assertEqual(ds._fingerprint, fingerprint)
del ds

open(os.path.join(tmp_dir, "text.txt"), "w", encoding="utf-8").write("\n".join("bar" for _ in range(10)))
open(os.path.join(tmp_dir, "text.txt"), "w", encoding="utf-8").write(
"\n".join("bar" for _ in range(n_samples))
)
ds = load_dataset(
"./datasets/text", data_files=os.path.join(tmp_dir, "text.txt"), cache_dir=tmp_dir, split="train"
)
self.assertNotEqual(ds._data_files[0], data_file)
self.assertNotEqual(ds._fingerprint, fingerprint)
del ds


class CsvTest(TestCase):
def test_caching(self):
n_rows = 10

features = Features({"foo": Value("string"), "bar": Value("string")})

with tempfile.TemporaryDirectory() as tmp_dir:
open(os.path.join(tmp_dir, "table.csv"), "w", encoding="utf-8").write(
"\n".join(",".join(["foo", "bar"]) for _ in range(n_rows + 1))
)
ds = load_dataset(
"./datasets/csv", data_files=os.path.join(tmp_dir, "table.csv"), cache_dir=tmp_dir, split="train"
)
data_file = ds._data_files[0]
fingerprint = ds._fingerprint
self.assertEqual(len(ds), n_rows)
del ds
ds = load_dataset(
"./datasets/csv", data_files=os.path.join(tmp_dir, "table.csv"), cache_dir=tmp_dir, split="train"
)
self.assertEqual(ds._data_files[0], data_file)
self.assertEqual(ds._fingerprint, fingerprint)
del ds
ds = load_dataset(
"./datasets/csv",
data_files=os.path.join(tmp_dir, "table.csv"),
cache_dir=tmp_dir,
split="train",
features=features,
)
self.assertNotEqual(ds._data_files[0], data_file)
self.assertNotEqual(ds._fingerprint, fingerprint)
del ds

open(os.path.join(tmp_dir, "table.csv"), "w", encoding="utf-8").write(
"\n".join(",".join(["Foo", "Bar"]) for _ in range(n_rows + 1))
)
ds = load_dataset(
"./datasets/csv", data_files=os.path.join(tmp_dir, "table.csv"), cache_dir=tmp_dir, split="train"
)
self.assertNotEqual(ds._data_files[0], data_file)
self.assertNotEqual(ds._fingerprint, fingerprint)
del ds

def test_features(self):
n_rows = 10
n_cols = 3

def get_features(type):
return Features({str(i): Value(type) for i in range(n_cols)})

with tempfile.TemporaryDirectory() as tmp_dir:
open(os.path.join(tmp_dir, "table.csv"), "w", encoding="utf-8").write(
"\n".join(",".join([str(i) for i in range(n_cols)]) for _ in range(n_rows + 1))
)
for type in ["float64", "int8"]:
features = get_features(type)
ds = load_dataset(
"./datasets/csv",
data_files=os.path.join(tmp_dir, "table.csv"),
cache_dir=tmp_dir,
split="train",
features=features,
)
self.assertEqual(len(ds), n_rows)
self.assertDictEqual(ds.features, features)
del ds
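
Taken together, the tests exercise the user-facing path: `features` passed to `load_dataset` flows into the CSV builder config, changes the cache fingerprint, and casts the columns on read. A usage sketch mirroring `CsvTest.test_features` (the data file path is a placeholder):

from datasets import Features, Value, load_dataset

# Declare the schema up front so the CSV columns are cast to the requested
# types instead of being inferred; the data file path below is a placeholder.
features = Features({"0": Value("int8"), "1": Value("int8"), "2": Value("int8")})
ds = load_dataset(
    "./datasets/csv",               # local csv builder script, as in the tests
    data_files="table.csv",
    split="train",
    features=features,
)
assert ds.features == features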