Allow writers to overwrite existing data (#594)
* Allow writers to overwrite existing data

* Add exist_ok arg to writer docs

* Fix linting

* Fix linting

* Move removal code into writer from cloud uploader

* Remove old function

---------

Co-authored-by: Xiaohan Zhang <xiaohanzhang.cmu@gmail.com>
JAEarly and XiaohanZhangCMU committed Feb 14, 2024
1 parent cf532fc commit 3de2e68
Showing 5 changed files with 124 additions and 1 deletion.
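
Before the per-file diffs, a minimal usage sketch of the new flag. This is hedged: the top-level MDSWriter import follows the streaming library's usual export, and the schema, sample dicts, and output directory are illustrative, not taken from this commit.

import shutil

from streaming import MDSWriter  # top-level export assumed

columns = {'id': 'int', 'text': 'str'}  # hypothetical schema
local = '/tmp/exist_ok_demo'  # hypothetical output directory
shutil.rmtree(local, ignore_errors=True)  # start from a clean directory

# The first writer populates the local directory.
with MDSWriter(out=local, columns=columns) as out:
    out.write({'id': 0, 'text': 'hello'})

# Reopening a non-empty directory raises FileExistsError unless exist_ok=True,
# which deletes the existing contents and starts fresh.
with MDSWriter(out=local, columns=columns, exist_ok=True) as out:
    out.write({'id': 1, 'text': 'world'})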
13 changes: 12 additions & 1 deletion streaming/base/format/base/writer.py
@@ -64,6 +64,9 @@ class Writer(ABC):
             file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
         retry (int): Number of times to retry uploading a file to a remote location.
             Default to ``2``.
+        exist_ok (bool): If the local directory exists and is not empty, whether to
+            overwrite its contents or raise an error. ``False`` raises an error; ``True``
+            deletes the existing contents and starts fresh. Defaults to ``False``.
     """

     format: str = ''  # Name of the format (like "mds", "csv", "json", etc).
@@ -100,7 +103,8 @@ def __init__(self,

         # Validate keyword arguments
         invalid_kwargs = [
-            arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers', 'retry')
+            arg for arg in kwargs.keys()
+            if arg not in ('progress_bar', 'max_workers', 'retry', 'exist_ok')
         ]
         if invalid_kwargs:
             raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ')
@@ -116,6 +120,13 @@

         self.shards = []

+        # Remove the local directory, if requested, prior to creating the writer
+        local = out if isinstance(out, str) else out[0]
+        if os.path.exists(local) and kwargs.get('exist_ok', False):
+            logger.warning(
+                f'Directory {local} exists and is not empty; exist_ok is set to True, so its contents will be removed.'
+            )
+            shutil.rmtree(local)
         self.cloud_writer = CloudUploader.get(out, keep_local, kwargs.get('progress_bar', False),
                                               kwargs.get('retry', 2))
         self.local = self.cloud_writer.local
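
A note on the branch above: out may be a plain path or a (local, remote) pair, hence out[0]. Restated as a standalone sketch (a hypothetical helper, not the library's API): with exist_ok=True an existing local directory is warned about and wiped before the CloudUploader is constructed; with exist_ok=False it is left in place, and the tests below expect the non-empty case to surface as FileExistsError('Directory is not empty').

import logging
import os
import shutil

logger = logging.getLogger(__name__)

def clear_local_if_exist_ok(local: str, exist_ok: bool) -> None:
    # Hypothetical helper mirroring the branch added in Writer.__init__.
    if os.path.exists(local) and exist_ok:
        logger.warning(f'Directory {local} exists; exist_ok=True, so removing its contents.')
        shutil.rmtree(local)
    # With exist_ok=False the directory is left untouched here; a later
    # non-empty check is expected (per the tests) to raise FileExistsError.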
3 changes: 3 additions & 0 deletions streaming/base/format/json/writer.py
@@ -45,6 +45,9 @@ class JSONWriter(SplitWriter):
         max_workers (int): Maximum number of threads used to upload output dataset files in
             parallel to a remote location. One thread is responsible for uploading one shard
             file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
+        exist_ok (bool): If the local directory exists and is not empty, whether to
+            overwrite its contents or raise an error. ``False`` raises an error; ``True``
+            deletes the existing contents and starts fresh. Defaults to ``False``.
     """

     format = 'json'
3 changes: 3 additions & 0 deletions streaming/base/format/mds/writer.py
@@ -45,6 +45,9 @@ class MDSWriter(JointWriter):
         max_workers (int): Maximum number of threads used to upload output dataset files in
             parallel to a remote location. One thread is responsible for uploading one shard
             file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
+        exist_ok (bool): If the local directory exists and is not empty, whether to
+            overwrite its contents or raise an error. ``False`` raises an error; ``True``
+            deletes the existing contents and starts fresh. Defaults to ``False``.
     """

     format = 'mds'
3 changes: 3 additions & 0 deletions streaming/base/format/xsv/writer.py
@@ -46,6 +46,9 @@ class XSVWriter(SplitWriter):
         max_workers (int): Maximum number of threads used to upload output dataset files in
             parallel to a remote location. One thread is responsible for uploading one shard
             file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``.
+        exist_ok (bool): If the local directory exists and is not empty, whether to
+            overwrite its contents or raise an error. ``False`` raises an error; ``True``
+            deletes the existing contents and starts fresh. Defaults to ``False``.
     """

     format = 'xsv'
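
The XSVWriter docstring gets the same addition; per the tests below, XSVWriter also requires a separator argument (CSVWriter and TSVWriter supply one implicitly), while exist_ok flows through the shared kwargs. A hedged sketch, importing from the module path shown in this diff rather than a confirmed public export:

from streaming.base.format.xsv.writer import XSVWriter  # module path from this diff

columns = {'id': 'int', 'text': 'str'}  # hypothetical schema
with XSVWriter(out='/tmp/xsv_demo', columns=columns, separator=',', exist_ok=True) as out:
    out.write({'id': 0, 'text': 'hello'})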
103 changes: 103 additions & 0 deletions tests/test_writer.py
@@ -121,6 +121,34 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s
         for before, after in zip(dataset, mds_dataset):
             assert before == after

+    def test_exist_ok(self, local_remote_dir: Tuple[str, str]) -> None:
+        num_samples = 1000
+        size_limit = 4096
+        local, _ = local_remote_dir
+        dataset = SequenceDataset(num_samples)
+        columns = dict(zip(dataset.column_names, dataset.column_encodings))
+
+        # Write the entire dataset initially
+        with MDSWriter(out=local, columns=columns, size_limit=size_limit) as out:
+            for sample in dataset:
+                out.write(sample)
+        num_orig_files = len(os.listdir(local))
+
+        # Write a single sample with exist_ok set to True
+        with MDSWriter(out=local, columns=columns, size_limit=size_limit, exist_ok=True) as out:
+            out.write(dataset[0])
+        num_files = len(os.listdir(local))
+
+        # Two files for the single sample (index.json and one shard)
+        assert num_files == 2
+        # More files were generated for the entire dataset, then deleted because exist_ok is True
+        assert num_orig_files > num_files
+
+        # Check that an exception is raised when exist_ok is False and local already exists
+        with pytest.raises(FileExistsError, match='Directory is not empty'):
+            with MDSWriter(out=local, columns=columns, size_limit=size_limit) as out:
+                out.write(dataset[0])
+

class TestJSONWriter:

@@ -177,6 +205,34 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s
         for before, after in zip(dataset, mds_dataset):
             assert before == after

+    def test_exist_ok(self, local_remote_dir: Tuple[str, str]) -> None:
+        num_samples = 1000
+        size_limit = 4096
+        local, _ = local_remote_dir
+        dataset = SequenceDataset(num_samples)
+        columns = dict(zip(dataset.column_names, dataset.column_encodings))
+
+        # Write the entire dataset initially
+        with JSONWriter(out=local, columns=columns, size_limit=size_limit) as out:
+            for sample in dataset:
+                out.write(sample)
+        num_orig_files = len(os.listdir(local))
+
+        # Write a single sample with exist_ok set to True
+        with JSONWriter(out=local, columns=columns, size_limit=size_limit, exist_ok=True) as out:
+            out.write(dataset[0])
+        num_files = len(os.listdir(local))
+
+        # Three files for the single sample (index.json, one shard, and one shard metadata file)
+        assert num_files == 3
+        # More files were generated for the entire dataset, then deleted because exist_ok is True
+        assert num_orig_files > num_files
+
+        # Check that an exception is raised when exist_ok is False and local already exists
+        with pytest.raises(FileExistsError, match='Directory is not empty'):
+            with JSONWriter(out=local, columns=columns, size_limit=size_limit) as out:
+                out.write(dataset[0])
+

class TestXSVWriter:

@@ -256,3 +312,50 @@ def test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s
         # Ensure sample iterator is deterministic
         for before, after in zip(dataset, mds_dataset):
             assert before == after
+
+    @pytest.mark.parametrize('writer', [XSVWriter, TSVWriter, CSVWriter])
+    def test_exist_ok(self, local_remote_dir: Tuple[str, str], writer: Any) -> None:
+        num_samples = 1000
+        size_limit = 4096
+        local, _ = local_remote_dir
+        dataset = SequenceDataset(num_samples)
+        columns = dict(zip(dataset.column_names, dataset.column_encodings))
+
+        # Write the entire dataset initially
+        if writer.__name__ == XSVWriter.__name__:
+            with writer(out=local, columns=columns, size_limit=size_limit, separator=',') as out:
+                for sample in dataset:
+                    out.write(sample)
+        else:
+            with writer(out=local, columns=columns, size_limit=size_limit) as out:
+                for sample in dataset:
+                    out.write(sample)
+        num_orig_files = len(os.listdir(local))
+
+        # Write a single sample with exist_ok set to True
+        if writer.__name__ == XSVWriter.__name__:
+            with writer(out=local,
+                        columns=columns,
+                        size_limit=size_limit,
+                        separator=',',
+                        exist_ok=True) as out:
+                out.write(dataset[0])
+        else:
+            with writer(out=local, columns=columns, size_limit=size_limit, exist_ok=True) as out:
+                out.write(dataset[0])
+        num_files = len(os.listdir(local))
+
+        # Three files for the single sample (index.json, one shard, and one shard metadata file)
+        assert num_files == 3
+        # More files were generated for the entire dataset, then deleted because exist_ok is True
+        assert num_orig_files > num_files
+
+        # Check that an exception is raised when exist_ok is False and local already exists
+        with pytest.raises(FileExistsError, match='Directory is not empty'):
+            if writer.__name__ == XSVWriter.__name__:
+                with writer(out=local, columns=columns, size_limit=size_limit,
+                            separator=',') as out:
+                    out.write(dataset[0])
+            else:
+                with writer(out=local, columns=columns, size_limit=size_limit) as out:
+                    out.write(dataset[0])
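
The new tests should be runnable with pytest's keyword filter, e.g. pytest tests/test_writer.py -k test_exist_ok.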
