From c1d443288373ceaf50aee665a9bc9831a82356a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 31 May 2024 11:34:41 -0400 Subject: [PATCH] Fix one-off edge case in split_lazy (#1347) --- lhotse/utils.py | 8 ++++++-- test/test_manipulation.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/lhotse/utils.py b/lhotse/utils.py index 23ed1fd45..fb171537f 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -310,9 +310,13 @@ def split_manifest_lazy( if prefix == "": prefix = "split" - items = iter(it) split_idx = start_idx splits = [] + items = iter(it) + try: + item = next(items) + except StopIteration: + return splits while True: try: written = 0 @@ -321,9 +325,9 @@ def split_manifest_lazy( (output_dir / prefix).with_suffix(f".{idx}.jsonl.gz") ) as writer: while written < chunk_size: - item = next(items) writer.write(item) written += 1 + item = next(items) split_idx += 1 except StopIteration: break diff --git a/test/test_manipulation.py b/test/test_manipulation.py index 6f692ed9c..9afd04dfe 100644 --- a/test/test_manipulation.py +++ b/test/test_manipulation.py @@ -113,6 +113,17 @@ def test_split_lazy_even(manifest_type): ) +def test_split_lazy_edge_case_extra_shard(tmp_path): + N = 512 + chsz = 32 + nshrd = 16 + manifest = DummyManifest(CutSet, begin_id=0, end_id=N - 1) + manifest_subsets = manifest.split_lazy(output_dir=tmp_path, chunk_size=chsz) + assert len(manifest_subsets) == nshrd + for item in sorted(tmp_path.glob("*")): + print(item) + + @mark.parametrize("manifest_type", [RecordingSet, SupervisionSet, FeatureSet, CutSet]) def test_combine(manifest_type): expected = DummyManifest(manifest_type, begin_id=0, end_id=200)