Skip to content

Commit

Permalink
Prevent accidental renaming when using with_suffix (#1192)
Browse files Browse the repository at this point in the history
* Prevent accidental renaming when using with_suffix

* Adding test for the fix

* Add black formatting and isort

* fix isort error
  • Loading branch information
chiiyeh committed Oct 25, 2023
1 parent a470186 commit 523b653
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 6 deletions.
30 changes: 24 additions & 6 deletions lhotse/features/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,10 @@ def write(self, key: str, value: np.ndarray) -> str:
# too many files in a single directory.
subdir = self.storage_path_ / key[:3]
subdir.mkdir(exist_ok=True)
output_features_path = (subdir / key).with_suffix(".llc")
p = subdir / key
output_features_path = p.with_suffix(
p.suffix + ".llc" if p.suffix != ".llc" else ".llc"
)
serialized_feats = lilcom.compress(value, tick_power=self.tick_power)
with open(output_features_path, "wb") as f:
f.write(serialized_feats)
Expand Down Expand Up @@ -343,7 +346,10 @@ def write(self, key: str, value: np.ndarray) -> str:
# too many files in a single directory.
subdir = self.storage_path_ / key[:3]
subdir.mkdir(exist_ok=True)
output_features_path = (subdir / key).with_suffix(".npy")
p = subdir / key
output_features_path = p.with_suffix(
p.suffix + ".npy" if p.suffix != ".npy" else ".npy"
)
np.save(output_features_path, value, allow_pickle=False)
# Include sub-directory in the key, e.g. "abc/abcdef.npy"
return "/".join(output_features_path.parts[-2:])
Expand Down Expand Up @@ -450,7 +456,10 @@ def __init__(self, storage_path: Pathlike, mode: str = "w", *args, **kwargs):
check_h5py_installed()
import h5py

self.storage_path_ = Path(storage_path).with_suffix(".h5")
p = Path(storage_path)
self.storage_path_ = p.with_suffix(
p.suffix + ".h5" if p.suffix != ".h5" else ".h5"
)
self.hdf = h5py.File(self.storage_path, mode=mode)

@property
Expand Down Expand Up @@ -539,7 +548,10 @@ def __init__(
check_h5py_installed()
import h5py

self.storage_path_ = Path(storage_path).with_suffix(".h5")
p = Path(storage_path)
self.storage_path_ = p.with_suffix(
p.suffix + ".h5" if p.suffix != ".h5" else ".h5"
)
self.hdf = h5py.File(self.storage_path, mode=mode)
self.tick_power = tick_power

Expand Down Expand Up @@ -664,7 +676,10 @@ def __init__(
check_h5py_installed()
import h5py

self.storage_path_ = Path(storage_path).with_suffix(".h5")
p = Path(storage_path)
self.storage_path_ = p.with_suffix(
p.suffix + ".h5" if p.suffix != ".h5" else ".h5"
)
self.tick_power = tick_power
self.chunk_size = chunk_size
self.hdf = h5py.File(self.storage_path, mode=mode)
Expand Down Expand Up @@ -826,7 +841,10 @@ def __init__(
assert mode in ("wb", "ab")

# ".lca" -> "lilcom chunky archive"
self.storage_path_ = Path(storage_path).with_suffix(".lca")
p = Path(storage_path)
self.storage_path_ = p.with_suffix(
p.suffix + ".lca" if p.suffix != ".lca" else ".lca"
)
self.tick_power = tick_power
self.file = open(self.storage_path, mode=mode)
self.curr_offset = self.file.tell()
Expand Down
81 changes: 81 additions & 0 deletions test/features/test_feature_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import pytest

from lhotse import (
ChunkedLilcomHdf5Writer,
LilcomChunkyWriter,
LilcomFilesWriter,
LilcomHdf5Writer,
NumpyFilesWriter,
NumpyHdf5Writer,
)
from lhotse.utils import is_module_available


@pytest.mark.parametrize(
["writer_type", "ext"],
[
(LilcomFilesWriter, ".llc"),
(NumpyFilesWriter, ".npy"),
],
)
def test_writer_saved_file(writer_type, ext):
# Generate small random numbers that are nicely compressed with lilcom
arr = np.log(np.random.uniform(size=(11, 80)).astype(np.float32) / 100)

with TemporaryDirectory() as d, writer_type(d) as writer:
# testing that words after . is not replace
input_key = "random0.3_vad.alpha"
key = writer.write(input_key, arr)
assert key == f"ran/{input_key}{ext}"

# Testing when end with extension it is not added again
input_key = f"temp0.2.alpha{ext}"
key = writer.write(input_key, arr)
assert key == f"tem/{input_key}"


@pytest.mark.parametrize(
["writer_type", "ext"],
[
pytest.param(
NumpyHdf5Writer,
".h5",
marks=pytest.mark.skipif(
not is_module_available("h5py"),
reason="Requires h5py to run HDF5 tests.",
),
),
pytest.param(
LilcomHdf5Writer,
".h5",
marks=pytest.mark.skipif(
not is_module_available("h5py"),
reason="Requires h5py to run HDF5 tests.",
),
),
pytest.param(
ChunkedLilcomHdf5Writer,
".h5",
marks=pytest.mark.skipif(
not is_module_available("h5py"),
reason="Requires h5py to run HDF5 tests.",
),
),
(LilcomChunkyWriter, ".lca"),
],
)
def test_chunk_writer_saved_file(writer_type, ext):
with TemporaryDirectory() as d:
# testing that words after . is not replace
filename = "random0.3_vad.alpha"
with writer_type(f"{d}/{filename}") as writer:
assert writer.storage_path_ == Path(f"{d}/{filename}{ext}")

# Testing when end with extension it is not added again
filename = f"random0.3_vad.alpha{ext}"
with writer_type(f"{d}/{filename}") as writer:
assert writer.storage_path_ == Path(f"{d}/{filename}")

0 comments on commit 523b653

Please sign in to comment.