diff --git a/lhotse/audio/recording.py b/lhotse/audio/recording.py index 0c3ea2ebc..971166995 100644 --- a/lhotse/audio/recording.py +++ b/lhotse/audio/recording.py @@ -2,7 +2,7 @@ from io import BytesIO from math import ceil, isclose from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -126,7 +126,7 @@ class Recording: num_samples: int duration: Seconds channel_ids: Optional[List[int]] = None - transforms: Optional[List[Dict]] = None + transforms: Optional[List[Union[AudioTransform, Dict]]] = None def __post_init__(self): if self.channel_ids is None: @@ -334,7 +334,10 @@ def _aslist(x): ) def to_dict(self) -> dict: - return asdict_nonull(self) + d = asdict_nonull(self) + if self.transforms is not None: + d["transforms"] = [t if isinstance(t, dict) else t.to_dict() for t in self.transforms] + return d def to_cut(self): """ @@ -395,7 +398,8 @@ def load_audio( ) transforms = [ - AudioTransform.from_dict(params) for params in self.transforms or [] + tnfm if isinstance(tnfm, AudioTransform) else AudioTransform.from_dict(tnfm) + for tnfm in self.transforms or [] ] # Do a "backward pass" over data augmentation transforms to get the @@ -488,10 +492,15 @@ def load_video( ) for t in ifnone(self.transforms, ()): - assert t["name"] not in ( - "Speed", - "Tempo", - ), "Recording.load_video() does not support speed/tempo perturbation." + if isinstance(t, dict): + assert t["name"] not in ( + "Speed", + "Tempo", + ), "Recording.load_video() does not support speed/tempo perturbation." + else: + assert not isinstance( + t, (Speed, Tempo) + ), "Recording.load_video() does not support speed/tempo perturbation." 
if not with_audio: video, _ = self._video_source.load_video( @@ -519,7 +528,8 @@ def load_video( ) transforms = [ - AudioTransform.from_dict(params) for params in self.transforms or [] + tnfm if isinstance(tnfm, AudioTransform) else AudioTransform.from_dict(tnfm) + for tnfm in self.transforms or [] ] # Do a "backward pass" over data augmentation transforms to get the @@ -659,7 +669,7 @@ def perturb_speed(self, factor: float, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(Speed(factor=factor).to_dict()) + transforms.append(Speed(factor=factor)) new_num_samples = perturb_num_samples(self.num_samples, factor) new_duration = new_num_samples / self.sampling_rate return fastcopy( @@ -684,7 +694,7 @@ def perturb_tempo(self, factor: float, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(Tempo(factor=factor).to_dict()) + transforms.append(Tempo(factor=factor)) new_num_samples = perturb_num_samples(self.num_samples, factor) new_duration = new_num_samples / self.sampling_rate return fastcopy( @@ -705,7 +715,7 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(Volume(factor=factor).to_dict()) + transforms.append(Volume(factor=factor)) return fastcopy( self, id=f"{self.id}_vp{factor}" if affix_id else self.id, @@ -722,7 +732,7 @@ def normalize_loudness(self, target: float, affix_id: bool = False) -> "Recordin :return: a modified copy of the current ``Recording``. 
""" transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(LoudnessNormalization(target=target).to_dict()) + transforms.append(LoudnessNormalization(target=target)) return fastcopy( self, id=f"{self.id}_ln{target}" if affix_id else self.id, @@ -738,7 +748,7 @@ def dereverb_wpe(self, affix_id: bool = True) -> "Recording": :return: a modified copy of the current ``Recording``. """ transforms = self.transforms.copy() if self.transforms is not None else [] - transforms.append(DereverbWPE().to_dict()) + transforms.append(DereverbWPE()) return fastcopy( self, id=f"{self.id}_wpe" if affix_id else self.id, @@ -751,7 +761,7 @@ def reverb_rir( normalize_output: bool = True, early_only: bool = False, affix_id: bool = True, - rir_channels: Optional[List[int]] = None, + rir_channels: Optional[Sequence[int]] = None, room_rng_seed: Optional[int] = None, source_rng_seed: Optional[int] = None, ) -> "Recording": @@ -812,7 +822,7 @@ def reverb_rir( early_only=early_only, rir_channels=rir_channels if rir_channels is not None else [0], rir_generator=rir_generator, - ).to_dict() + ) ) return fastcopy( self, @@ -835,7 +845,7 @@ def resample(self, sampling_rate: int) -> "Recording": Resample( source_sampling_rate=self.sampling_rate, target_sampling_rate=sampling_rate, - ).to_dict() + ) ) new_num_samples = compute_num_samples( diff --git a/lhotse/augmentation/rir.py b/lhotse/augmentation/rir.py index 09f5dc38c..bf2e8a1a5 100644 --- a/lhotse/augmentation/rir.py +++ b/lhotse/augmentation/rir.py @@ -35,10 +35,14 @@ class ReverbWithImpulseResponse(AudioTransform): def __post_init__(self): if isinstance(self.rir, dict): - from lhotse import Recording + from lhotse.serialization import deserialize_item - # Pass a shallow copy of the RIR dict since `from_dict()` pops the `sources` key. 
- self.rir = Recording.from_dict(self.rir.copy()) + # Pass a shallow copy of the RIR dict since deserialization is destructive + # If RIR is a Cut, we have to perform one extra copy (hacky but better than deepcopy). + rir = self.rir.copy() + if "recording" in self.rir: + rir["recording"] = rir["recording"].copy() + self.rir = deserialize_item(rir) assert ( self.rir is not None or self.rir_generator is not None @@ -52,6 +56,23 @@ def __post_init__(self): if self.rir_generator is not None and isinstance(self.rir_generator, dict): self.rir_generator = FastRandomRIRGenerator(**self.rir_generator) + def to_dict(self) -> dict: + from lhotse import Recording + from lhotse.cut import Cut + + return { + "name": type(self).__name__, + "kwargs": { + "rir": self.rir.to_dict() + if isinstance(self.rir, (Recording, Cut)) + else self.rir, + "normalize_output": self.normalize_output, + "early_only": self.early_only, + "rir_channels": list(self.rir_channels), + "rir_generator": self.rir_generator, + }, + } + def __call__( self, samples: np.ndarray, @@ -92,11 +113,13 @@ def __call__( if self.rir is None: rir_ = self.rir_generator(nsource=1) else: - rir_ = ( - self.rir.load_audio(channels=self.rir_channels) - if not self.early_only - else self.rir.load_audio(channels=self.rir_channels, duration=0.05) - ) + from lhotse import Recording + + rir = self.rir.to_cut() if isinstance(self.rir, Recording) else self.rir + rir = rir.with_channels(self.rir_channels) + if self.early_only: + rir = rir.truncate(duration=0.05) + rir_ = rir.load_audio() D_rir, N_rir = rir_.shape N_out = N_in # Enforce shift output diff --git a/lhotse/cut/base.py b/lhotse/cut/base.py index 7ab95d991..fa1400a4a 100644 --- a/lhotse/cut/base.py +++ b/lhotse/cut/base.py @@ -22,7 +22,6 @@ compute_start_duration_for_extended_cut, fastcopy, ifnone, - is_torchaudio_available, overlaps, to_hashable, ) diff --git a/lhotse/cut/data.py b/lhotse/cut/data.py index 438805eb3..7229353af 100644 --- a/lhotse/cut/data.py +++ 
b/lhotse/cut/data.py @@ -22,6 +22,7 @@ Seconds, TimeSpan, add_durations, + asdict_nonull, compute_num_frames, compute_num_samples, fastcopy, @@ -70,6 +71,16 @@ class DataCut(Cut, CustomFieldMixin, metaclass=ABCMeta): # Store anything else the user might want. custom: Optional[Dict[str, Any]] = None + def to_dict(self) -> dict: + d = asdict_nonull(self) + if self.has_recording: + d["recording"] = self.recording.to_dict() + if self.custom is not None: + for k, v in self.custom.items(): + if isinstance(v, Recording): + d["custom"][k] = v.to_dict() + return {**d, "type": type(self).__name__} + @property def recording_id(self) -> str: return self.recording.id if self.has_recording else self.features.recording_id diff --git a/lhotse/cut/mono.py b/lhotse/cut/mono.py index 1cecfb2c0..75bbe23f7 100644 --- a/lhotse/cut/mono.py +++ b/lhotse/cut/mono.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from functools import partial, reduce from operator import add -from typing import Any, Callable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -16,6 +16,7 @@ add_durations, fastcopy, hash_str_to_int, + is_equal_or_contains, merge_items_with_delimiter, overlaps, rich_exception_info, @@ -102,13 +103,59 @@ def load_video( ) return None + def with_channels(self, channels: Union[List[int], int]) -> DataCut: + """ + Select specified channels from this cut. + Supports extending to other channels available in the underlying :class:`Recording`. + If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`, + otherwise we'll return a :class:`~lhotse.cut.MultiCut`. 
+ """ + channel_is_int = isinstance(channels, int) + assert set([channels] if channel_is_int else channels).issubset( + set(self.recording.channel_ids) + ), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}" + + mono = channel_is_int or len(channels) == 1 + if mono: + if not channel_is_int: + (channels,) = channels + return MonoCut( + id=f"{self.id}-{channels}", + recording=self.recording, + start=self.start, + duration=self.duration, + channel=channels, + supervisions=[ + fastcopy(s, channel=channels) + for s in self.supervisions + if is_equal_or_contains(s.channel, channels) + ], + custom=self.custom, + ) + else: + from lhotse import MultiCut + + return MultiCut( + id=f"{self.id}-{len(channels)}chan", + recording=self.recording, + start=self.start, + duration=self.duration, + channel=channels, + supervisions=[ + s + for s in self.supervisions + if is_equal_or_contains(channels, s.channel) + ], + custom=self.custom, + ) + def reverb_rir( self, - rir_recording: Optional["Recording"] = None, + rir_recording: Optional[Union[Recording, DataCut]] = None, normalize_output: bool = True, early_only: bool = False, affix_id: bool = True, - rir_channels: List[int] = [0], + rir_channels: Sequence[int] = (0,), room_rng_seed: Optional[int] = None, source_rng_seed: Optional[int] = None, ) -> DataCut: diff --git a/lhotse/cut/multi.py b/lhotse/cut/multi.py index 439f9d217..3544b7c4d 100644 --- a/lhotse/cut/multi.py +++ b/lhotse/cut/multi.py @@ -157,11 +157,11 @@ def load_video( def reverb_rir( self, - rir_recording: Optional["Recording"] = None, + rir_recording: Optional[Union[Recording, DataCut]] = None, normalize_output: bool = True, early_only: bool = False, affix_id: bool = True, - rir_channels: List[int] = [0], + rir_channels: Sequence[int] = (0,), room_rng_seed: Optional[int] = None, source_rng_seed: Optional[int] = None, ) -> "MultiCut": @@ -370,17 +370,18 @@ def with_channels(self, channels: Union[List[int], int]) -> DataCut: Select specified channels from 
this cut. Supports extending to other channels available in the underlying :class:`Recording`. If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`, - otherwise we'll return a :class:`~lhotse.cut.MultiCut'. + otherwise we'll return a :class:`~lhotse.cut.MultiCut`. """ - mono = isinstance(channels, int) or len(channels) == 1 - assert set([channels] if mono else channels).issubset( + channel_is_int = isinstance(channels, int) + assert set([channels] if channel_is_int else channels).issubset( set(self.recording.channel_ids) ), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}" + mono = channel_is_int or len(channels) == 1 if mono: from .mono import MonoCut - if isinstance(channels, Sequence): + if not channel_is_int: (channels,) = channels return MonoCut( id=f"{self.id}-{channels}", diff --git a/test/cut/test_cut_augmentation.py b/test/cut/test_cut_augmentation.py index 133321ce5..9219cebbc 100644 --- a/test/cut/test_cut_augmentation.py +++ b/test/cut/test_cut_augmentation.py @@ -1,3 +1,6 @@ +import os +from tempfile import NamedTemporaryFile + import numpy as np import pytest import torch @@ -6,7 +9,7 @@ from lhotse.audio import RecordingSet from lhotse.cut import PaddingCut from lhotse.testing.dummies import dummy_cut, dummy_multi_cut -from lhotse.utils import fastcopy, is_module_available +from lhotse.utils import fastcopy, is_module_available, nullcontext @pytest.fixture @@ -652,9 +655,14 @@ def test_cut_normalize_loudness(libri_cut_set, target, mix_first): assert loudness == pytest.approx(target, abs=0.5) -def test_cut_reverb_rir(libri_cut_with_supervision, libri_recording_rvb, rir): +@pytest.mark.parametrize("in_memory", [True, False]) +def test_cut_reverb_rir( + libri_cut_with_supervision, libri_recording_rvb, rir, in_memory +): cut = libri_cut_with_supervision + if in_memory: + rir = rir.move_to_memory() cut_rvb = cut.reverb_rir(rir) assert cut_rvb.start == cut.start assert cut_rvb.duration == 
cut.duration @@ -676,6 +684,48 @@ def test_cut_reverb_rir(libri_cut_with_supervision, libri_recording_rvb, rir): np.testing.assert_array_almost_equal(cut_rvb.load_audio(), rvb_audio_from_fixture) +@pytest.mark.parametrize("with_serialization", [True, False]) +def test_cut_reverb_rir_input_is_cut( + libri_cut_with_supervision, libri_recording_rvb, rir, with_serialization +): + + cut = libri_cut_with_supervision + rir = rir.to_cut() + + with ( + NamedTemporaryFile(suffix=".jsonl", mode="w") + if with_serialization + else nullcontext() + ) as f: + if with_serialization: + CutSet([rir]).to_file(f.name) + f.flush() + os.fsync(f.fileno()) + rir = CutSet.from_file(f.name)[0] + + cut_rvb = cut.reverb_rir(rir) + assert cut_rvb.start == cut.start + assert cut_rvb.duration == cut.duration + assert cut_rvb.end == cut.end + assert cut_rvb.num_samples == cut.num_samples + + assert cut_rvb.recording.duration == cut.recording.duration + assert cut_rvb.recording.num_samples == cut.recording.num_samples + + assert cut_rvb.supervisions[0].start == cut.supervisions[0].start + assert cut_rvb.supervisions[0].duration == cut.supervisions[0].duration + assert cut_rvb.supervisions[0].end == cut.supervisions[0].end + + assert cut_rvb.load_audio().shape == cut.load_audio().shape + assert cut_rvb.recording.load_audio().shape == cut.recording.load_audio().shape + + rvb_audio_from_fixture = libri_recording_rvb.load_audio() + + np.testing.assert_array_almost_equal( + cut_rvb.load_audio(), rvb_audio_from_fixture + ) + + def test_cut_reverb_rir_assert_sampling_rate(libri_cut_with_supervision, rir): cut = libri_cut_with_supervision rir_new = rir.resample(8000)