Skip to content

Commit

Permalink
fix: ../../TEMPLATE/asr1/pyscripts/audio/format_wav_scp.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
kamo-naoyuki committed Apr 28, 2023
1 parent 866c801 commit 88dbe01
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 24 deletions.
34 changes: 15 additions & 19 deletions egs2/TEMPLATE/asr1/pyscripts/audio/format_wav_scp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -106,9 +106,8 @@ def generator(self):

cached = {}
for utt, (recodeid, st, et) in self.segments_dict.items():
wavpath = self.wav_dict[recodeid]
if recodeid not in cached:
wavpath = self.wav_dict[recodeid]

if wavpath.endswith("|"):
if self.multi_columns:
raise RuntimeError(
Expand All @@ -117,37 +116,33 @@ def generator(self):
# Streaming input e.g. cat a.wav |
with kaldiio.open_like_kaldi(wavpath, "rb") as f:
with BytesIO(f.read()) as g:
retval = soundfile.read(g)
array, rate = soundfile.read(g)

else:
if self.multi_columns:
retval = soundfile_read(
array, rate = soundfile_read(
wavs=wavpath.split(),
dtype=None,
always_2d=False,
concat_axis=1,
)
else:
retval = soundfile.read(wavpath)

cached[recodeid] = retval
array, rate = soundfile.read(wavpath)
cached[recodeid] = array, rate

array, rate = cached[recodeid]
# Keep array until the last query
recodeid_counter[recodeid] -= 1
if recodeid_counter[recodeid] == 0:
cached.pop(recodeid)
# Convert starting time of the segment to corresponding sample number.
# If end time is -1 then use the whole file starting from start time.
if et != -1:
array = array[int(st * rate) : int(et * rate)]
else:
array = array[int(st * rate) :]

yield utt, self._return(retval, st, et), None, None

def _return(self, array, st, et):
if isinstance(array, (tuple, list)):
array, rate = array

# Convert starting time of the segment to corresponding sample number.
# If end time is -1 then use the whole file starting from start time.
if et != -1:
return array[int(st * rate) : int(et * rate)], rate
else:
return array[int(st * rate) :], rate
yield utt, (array, rate), None, None


def main():
Expand Down Expand Up @@ -283,6 +278,7 @@ def generator():
dtype=None,
always_2d=False,
concat_axis=1,
return_subtype=True,
)
else:
with soundfile.SoundFile(wavpath) as sf:
Expand Down
22 changes: 17 additions & 5 deletions espnet2/fileio/sound_scp.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,6 +14,9 @@ def soundfile_read(
dtype=None,
always_2d: bool = False,
concat_axis: int = 1,
start: int = 0,
end: int = None,
return_subtype: bool = False,
) -> Tuple[np.array, int]:
if isinstance(wavs, str):
wavs = [wavs]
Expand All @@ -24,11 +27,17 @@ def soundfile_read(
prev_wav = None
for wav in wavs:
with soundfile.SoundFile(wav) as f:
# for supporting half-precision training
f.seek(start)
if end is not None:
frames = end - start

Check warning on line 32 in espnet2/fileio/sound_scp.py

View check run for this annotation

Codecov / codecov/patch

espnet2/fileio/sound_scp.py#L32

Added line #L32 was not covered by tests
else:
frames = -1
if dtype == "float16":
array = f.read(dtype="float32", always_2d=always_2d).astype(dtype)
array = f.read(

Check warning on line 36 in espnet2/fileio/sound_scp.py

View check run for this annotation

Codecov / codecov/patch

espnet2/fileio/sound_scp.py#L36

Added line #L36 was not covered by tests
frames, dtype="float32", always_2d=always_2d,
).astype(dtype)
else:
array = f.read(dtype=dtype, always_2d=always_2d)
array = f.read(frames, dtype=dtype, always_2d=always_2d)
rate = f.samplerate
subtype = f.subtype
subtypes.append(subtype)
Expand Down Expand Up @@ -61,7 +70,10 @@ def soundfile_read(
else:
array = np.concatenate(arrays, axis=concat_axis)

return array, rate, subtypes
if return_subtype:
return array, rate, subtypes

Check warning on line 74 in espnet2/fileio/sound_scp.py

View check run for this annotation

Codecov / codecov/patch

espnet2/fileio/sound_scp.py#L74

Added line #L74 was not covered by tests
else:
return array, rate


class SoundScpReader(collections.abc.Mapping):
Expand Down Expand Up @@ -124,7 +136,7 @@ def __init__(
def __getitem__(self, key) -> Tuple[int, np.ndarray]:
wavs = self.data[key]

array, rate, _ = soundfile_read(
array, rate = soundfile_read(
wavs,
dtype=self.dtype,
always_2d=self.always_2d,
Expand Down

0 comments on commit 88dbe01

Please sign in to comment.