Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow writers to overwrite existing data #594

Merged
merged 7 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 12 additions & 1 deletion streaming/base/format/base/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class Writer(ABC):
file to a remote location. Defaults to ``min(32, (os.cpu_count() or 1) + 4)``.
retry (int): Number of times to retry uploading a file to a remote location.
Defaults to ``2``.
exist_ok (bool): If the local directory exists and is not empty, whether to overwrite
the content or raise an error. ``False`` raises an error. ``True`` deletes the
content and starts fresh. Defaults to ``False``.
"""

format: str = '' # Name of the format (like "mds", "csv", "json", etc).
Expand Down Expand Up @@ -100,7 +103,8 @@ def __init__(self,

# Validate keyword arguments
invalid_kwargs = [
arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers', 'retry')
arg for arg in kwargs.keys()
if arg not in ('progress_bar', 'max_workers', 'retry', 'exist_ok')
JAEarly marked this conversation as resolved.
Show resolved Hide resolved
]
if invalid_kwargs:
raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ')
Expand All @@ -116,6 +120,13 @@ def __init__(self,

self.shards = []

# Remove local directory if requested prior to creating writer
local = out if isinstance(out, str) else out[0]
if os.path.exists(local) and kwargs.get('exist_ok', False):
XiaohanZhangCMU marked this conversation as resolved.
Show resolved Hide resolved
logger.warning(
f'Directory {local} exists and is not empty; exist_ok is set to True so will remove contents.'
)
shutil.rmtree(local)
self.cloud_writer = CloudUploader.get(out, keep_local, kwargs.get('progress_bar', False),
kwargs.get('retry', 2))
self.local = self.cloud_writer.local
Expand Down
3 changes: 3 additions & 0 deletions streaming/base/format/json/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class JSONWriter(SplitWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
exist_ok (bool): If the local directory exists and is not empty, whether to overwrite
the content or raise an error. `False` raises an error. `True` deletes the
content and starts fresh. Defaults to `False`.
"""

format = 'json'
Expand Down
3 changes: 3 additions & 0 deletions streaming/base/format/mds/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class MDSWriter(JointWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
exist_ok (bool): If the local directory exists and is not empty, whether to overwrite
the content or raise an error. `False` raises an error. `True` deletes the
content and starts fresh. Defaults to `False`.
"""

format = 'mds'
Expand Down
3 changes: 3 additions & 0 deletions streaming/base/format/xsv/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class XSVWriter(SplitWriter):
max_workers (int): Maximum number of threads used to upload output dataset files in
parallel to a remote location. One thread is responsible for uploading one shard
file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
exist_ok (bool): If the local directory exists and is not empty, whether to overwrite
the content or raise an error. `False` raises an error. `True` deletes the
content and starts fresh. Defaults to `False`.
"""

format = 'xsv'
Expand Down
103 changes: 103 additions & 0 deletions tests/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,34 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s
for before, after in zip(dataset, mds_dataset):
assert before == after

def test_exist_ok(self, local_remote_dir: Tuple[str, str]) -> None:
    """Check ``exist_ok`` handling for ``MDSWriter``.

    ``exist_ok=True`` must wipe a non-empty output directory before writing;
    the default (``exist_ok=False``) must raise when the directory has content.
    """
    local, _ = local_remote_dir
    dataset = SequenceDataset(1000)
    columns = dict(zip(dataset.column_names, dataset.column_encodings))
    writer_kwargs = dict(out=local, columns=columns, size_limit=4096)

    # Populate the output directory with the full dataset first.
    with MDSWriter(**writer_kwargs) as out:
        for sample in dataset:
            out.write(sample)
    num_orig_files = len(os.listdir(local))

    # Overwrite with a single sample; exist_ok=True clears the old content.
    with MDSWriter(**writer_kwargs, exist_ok=True) as out:
        out.write(dataset[0])
    num_files = len(os.listdir(local))

    # A single sample yields exactly two files: index.json and one shard.
    assert num_files == 2
    # The full dataset produced more files, all removed by the overwrite.
    assert num_orig_files > num_files

    # Without exist_ok, writing into the non-empty directory must fail.
    with pytest.raises(FileExistsError, match='Directory is not empty'):
        with MDSWriter(**writer_kwargs) as out:
            out.write(dataset[0])


class TestJSONWriter:

Expand Down Expand Up @@ -177,6 +205,34 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s
for before, after in zip(dataset, mds_dataset):
assert before == after

def test_exist_ok(self, local_remote_dir: Tuple[str, str]) -> None:
    """Check ``exist_ok`` handling for ``JSONWriter``.

    ``exist_ok=True`` must wipe a non-empty output directory before writing;
    the default (``exist_ok=False``) must raise when the directory has content.
    """
    local, _ = local_remote_dir
    dataset = SequenceDataset(1000)
    columns = dict(zip(dataset.column_names, dataset.column_encodings))
    writer_kwargs = dict(out=local, columns=columns, size_limit=4096)

    # Populate the output directory with the full dataset first.
    with JSONWriter(**writer_kwargs) as out:
        for sample in dataset:
            out.write(sample)
    num_orig_files = len(os.listdir(local))

    # Overwrite with a single sample; exist_ok=True clears the old content.
    with JSONWriter(**writer_kwargs, exist_ok=True) as out:
        out.write(dataset[0])
    num_files = len(os.listdir(local))

    # A single sample yields three files: index.json, one shard, and its metadata.
    assert num_files == 3
    # The full dataset produced more files, all removed by the overwrite.
    assert num_orig_files > num_files

    # Without exist_ok, writing into the non-empty directory must fail.
    with pytest.raises(FileExistsError, match='Directory is not empty'):
        with JSONWriter(**writer_kwargs) as out:
            out.write(dataset[0])


class TestXSVWriter:

Expand Down Expand Up @@ -256,3 +312,50 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s
# Ensure sample iterator is deterministic
for before, after in zip(dataset, mds_dataset):
assert before == after

@pytest.mark.parametrize('writer', [XSVWriter, TSVWriter, CSVWriter])
def test_exist_ok(self, local_remote_dir: Tuple[str, str], writer: Any) -> None:
    """Check ``exist_ok`` handling for each XSV-family writer.

    ``exist_ok=True`` must wipe a non-empty output directory before writing;
    the default (``exist_ok=False``) must raise when the directory has content.

    Args:
        local_remote_dir (Tuple[str, str]): Fixture providing (local, remote) dirs.
        writer (Any): Writer class under test (XSVWriter, TSVWriter, or CSVWriter).
    """
    num_samples = 1000
    local, _ = local_remote_dir
    dataset = SequenceDataset(num_samples)
    columns = dict(zip(dataset.column_names, dataset.column_encodings))

    # Build the constructor kwargs once instead of duplicating each `with`
    # statement per branch. Only the generic XSVWriter needs an explicit
    # separator; TSVWriter/CSVWriter hard-code theirs. Compare classes with
    # `is` rather than `__name__` so a same-named class elsewhere can never
    # match by accident.
    writer_kwargs = dict(out=local, columns=columns, size_limit=4096)
    if writer is XSVWriter:
        writer_kwargs['separator'] = ','

    # Populate the output directory with the full dataset first.
    with writer(**writer_kwargs) as out:
        for sample in dataset:
            out.write(sample)
    num_orig_files = len(os.listdir(local))

    # Overwrite with a single sample; exist_ok=True clears the old content.
    with writer(**writer_kwargs, exist_ok=True) as out:
        out.write(dataset[0])
    num_files = len(os.listdir(local))

    # A single sample yields three files: index.json, one shard, and its metadata.
    assert num_files == 3
    # The full dataset produced more files, all removed by the overwrite.
    assert num_orig_files > num_files

    # Without exist_ok, writing into the non-empty directory must fail.
    with pytest.raises(FileExistsError, match='Directory is not empty'):
        with writer(**writer_kwargs) as out:
            out.write(dataset[0])