diff --git a/lhotse/features/io.py b/lhotse/features/io.py index bfcf268a9..4a7b812ab 100644 --- a/lhotse/features/io.py +++ b/lhotse/features/io.py @@ -280,7 +280,10 @@ def write(self, key: str, value: np.ndarray) -> str: # too many files in a single directory. subdir = self.storage_path_ / key[:3] subdir.mkdir(exist_ok=True) - output_features_path = (subdir / key).with_suffix(".llc") + p = subdir / key + output_features_path = p.with_suffix( + p.suffix + ".llc" if p.suffix != ".llc" else ".llc" + ) serialized_feats = lilcom.compress(value, tick_power=self.tick_power) with open(output_features_path, "wb") as f: f.write(serialized_feats) @@ -343,7 +346,10 @@ def write(self, key: str, value: np.ndarray) -> str: # too many files in a single directory. subdir = self.storage_path_ / key[:3] subdir.mkdir(exist_ok=True) - output_features_path = (subdir / key).with_suffix(".npy") + p = subdir / key + output_features_path = p.with_suffix( + p.suffix + ".npy" if p.suffix != ".npy" else ".npy" + ) np.save(output_features_path, value, allow_pickle=False) # Include sub-directory in the key, e.g. 
"abc/abcdef.npy" return "/".join(output_features_path.parts[-2:]) @@ -450,7 +456,10 @@ def __init__(self, storage_path: Pathlike, mode: str = "w", *args, **kwargs): check_h5py_installed() import h5py - self.storage_path_ = Path(storage_path).with_suffix(".h5") + p = Path(storage_path) + self.storage_path_ = p.with_suffix( + p.suffix + ".h5" if p.suffix != ".h5" else ".h5" + ) self.hdf = h5py.File(self.storage_path, mode=mode) @property @@ -539,7 +548,10 @@ def __init__( check_h5py_installed() import h5py - self.storage_path_ = Path(storage_path).with_suffix(".h5") + p = Path(storage_path) + self.storage_path_ = p.with_suffix( + p.suffix + ".h5" if p.suffix != ".h5" else ".h5" + ) self.hdf = h5py.File(self.storage_path, mode=mode) self.tick_power = tick_power @@ -664,7 +676,10 @@ def __init__( check_h5py_installed() import h5py - self.storage_path_ = Path(storage_path).with_suffix(".h5") + p = Path(storage_path) + self.storage_path_ = p.with_suffix( + p.suffix + ".h5" if p.suffix != ".h5" else ".h5" + ) self.tick_power = tick_power self.chunk_size = chunk_size self.hdf = h5py.File(self.storage_path, mode=mode) @@ -826,7 +841,10 @@ def __init__( assert mode in ("wb", "ab") # ".lca" -> "lilcom chunky archive" - self.storage_path_ = Path(storage_path).with_suffix(".lca") + p = Path(storage_path) + self.storage_path_ = p.with_suffix( + p.suffix + ".lca" if p.suffix != ".lca" else ".lca" + ) self.tick_power = tick_power self.file = open(self.storage_path, mode=mode) self.curr_offset = self.file.tell() diff --git a/test/features/test_feature_writer.py b/test/features/test_feature_writer.py new file mode 100644 index 000000000..83b07cb24 --- /dev/null +++ b/test/features/test_feature_writer.py @@ -0,0 +1,81 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +import numpy as np +import pytest + +from lhotse import ( + ChunkedLilcomHdf5Writer, + LilcomChunkyWriter, + LilcomFilesWriter, + LilcomHdf5Writer, + NumpyFilesWriter, + NumpyHdf5Writer, +) +from 
lhotse.utils import is_module_available + + +@pytest.mark.parametrize( + ["writer_type", "ext"], + [ + (LilcomFilesWriter, ".llc"), + (NumpyFilesWriter, ".npy"), + ], +) +def test_writer_saved_file(writer_type, ext): + # Generate small random numbers that are nicely compressed with lilcom + arr = np.log(np.random.uniform(size=(11, 80)).astype(np.float32) / 100) + + with TemporaryDirectory() as d, writer_type(d) as writer: + # Testing that the text after the last "." in the key is not replaced + input_key = "random0.3_vad.alpha" + key = writer.write(input_key, arr) + assert key == f"ran/{input_key}{ext}" + + # Testing that when the key already ends with the extension, it is not added again + input_key = f"temp0.2.alpha{ext}" + key = writer.write(input_key, arr) + assert key == f"tem/{input_key}" + + +@pytest.mark.parametrize( + ["writer_type", "ext"], + [ + pytest.param( + NumpyHdf5Writer, + ".h5", + marks=pytest.mark.skipif( + not is_module_available("h5py"), + reason="Requires h5py to run HDF5 tests.", + ), + ), + pytest.param( + LilcomHdf5Writer, + ".h5", + marks=pytest.mark.skipif( + not is_module_available("h5py"), + reason="Requires h5py to run HDF5 tests.", + ), + ), + pytest.param( + ChunkedLilcomHdf5Writer, + ".h5", + marks=pytest.mark.skipif( + not is_module_available("h5py"), + reason="Requires h5py to run HDF5 tests.", + ), + ), + (LilcomChunkyWriter, ".lca"), + ], +) +def test_chunk_writer_saved_file(writer_type, ext): + with TemporaryDirectory() as d: + # Testing that the text after the last "." in the filename is not replaced + filename = "random0.3_vad.alpha" + with writer_type(f"{d}/{filename}") as writer: + assert writer.storage_path_ == Path(f"{d}/{filename}{ext}") + + # Testing that when the filename already ends with the extension, it is not added again + filename = f"random0.3_vad.alpha{ext}" + with writer_type(f"{d}/{filename}") as writer: + assert writer.storage_path_ == Path(f"{d}/{filename}")