reverb_rir: support Cut input and in memory data #1332

Merged · 2 commits · May 16, 2024
44 changes: 27 additions & 17 deletions lhotse/audio/recording.py
@@ -2,7 +2,7 @@
from io import BytesIO
from math import ceil, isclose
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -126,7 +126,7 @@ class Recording:
num_samples: int
duration: Seconds
channel_ids: Optional[List[int]] = None
transforms: Optional[List[Dict]] = None
transforms: Optional[List[Union[AudioTransform, Dict]]] = None

def __post_init__(self):
if self.channel_ids is None:
@@ -334,7 +334,10 @@ def _aslist(x):
)

def to_dict(self) -> dict:
return asdict_nonull(self)
d = asdict_nonull(self)
if self.transforms is not None:
d["transforms"] = [t.to_dict() for t in self.transforms]
return d

def to_cut(self):
"""
@@ -395,7 +398,8 @@ def load_audio(
)

transforms = [
AudioTransform.from_dict(params) for params in self.transforms or []
tnfm if isinstance(tnfm, AudioTransform) else AudioTransform.from_dict(tnfm)
for tnfm in self.transforms or []
]

# Do a "backward pass" over data augmentation transforms to get the
@@ -488,10 +492,15 @@ def load_video(
)

for t in ifnone(self.transforms, ()):
assert t["name"] not in (
"Speed",
"Tempo",
), "Recording.load_video() does not support speed/tempo perturbation."
if isinstance(t, dict):
assert t["name"] not in (
"Speed",
"Tempo",
), "Recording.load_video() does not support speed/tempo perturbation."
else:
assert not isinstance(
t, (Speed, Tempo)
), "Recording.load_video() does not support speed/tempo perturbation."

if not with_audio:
video, _ = self._video_source.load_video(
@@ -519,7 +528,8 @@ def load_video(
)

transforms = [
AudioTransform.from_dict(params) for params in self.transforms or []
tnfm if isinstance(tnfm, AudioTransform) else AudioTransform.from_dict(tnfm)
for tnfm in self.transforms or []
]

# Do a "backward pass" over data augmentation transforms to get the
@@ -659,7 +669,7 @@ def perturb_speed(self, factor: float, affix_id: bool = True) -> "Recording":
:return: a modified copy of the current ``Recording``.
"""
transforms = self.transforms.copy() if self.transforms is not None else []
transforms.append(Speed(factor=factor).to_dict())
transforms.append(Speed(factor=factor))
new_num_samples = perturb_num_samples(self.num_samples, factor)
new_duration = new_num_samples / self.sampling_rate
return fastcopy(
@@ -684,7 +694,7 @@ def perturb_tempo(self, factor: float, affix_id: bool = True) -> "Recording":
:return: a modified copy of the current ``Recording``.
"""
transforms = self.transforms.copy() if self.transforms is not None else []
transforms.append(Tempo(factor=factor).to_dict())
transforms.append(Tempo(factor=factor))
new_num_samples = perturb_num_samples(self.num_samples, factor)
new_duration = new_num_samples / self.sampling_rate
return fastcopy(
@@ -705,7 +715,7 @@ def perturb_volume(self, factor: float, affix_id: bool = True) -> "Recording":
:return: a modified copy of the current ``Recording``.
"""
transforms = self.transforms.copy() if self.transforms is not None else []
transforms.append(Volume(factor=factor).to_dict())
transforms.append(Volume(factor=factor))
return fastcopy(
self,
id=f"{self.id}_vp{factor}" if affix_id else self.id,
@@ -722,7 +732,7 @@ def normalize_loudness(self, target: float, affix_id: bool = False) -> "Recording":
:return: a modified copy of the current ``Recording``.
"""
transforms = self.transforms.copy() if self.transforms is not None else []
transforms.append(LoudnessNormalization(target=target).to_dict())
transforms.append(LoudnessNormalization(target=target))
return fastcopy(
self,
id=f"{self.id}_ln{target}" if affix_id else self.id,
@@ -738,7 +748,7 @@ def dereverb_wpe(self, affix_id: bool = True) -> "Recording":
:return: a modified copy of the current ``Recording``.
"""
transforms = self.transforms.copy() if self.transforms is not None else []
transforms.append(DereverbWPE().to_dict())
transforms.append(DereverbWPE())
return fastcopy(
self,
id=f"{self.id}_wpe" if affix_id else self.id,
@@ -751,7 +761,7 @@ def reverb_rir(
normalize_output: bool = True,
early_only: bool = False,
affix_id: bool = True,
rir_channels: Optional[List[int]] = None,
rir_channels: Optional[Sequence[int]] = None,
room_rng_seed: Optional[int] = None,
source_rng_seed: Optional[int] = None,
) -> "Recording":
@@ -812,7 +822,7 @@ def reverb_rir(
early_only=early_only,
rir_channels=rir_channels if rir_channels is not None else [0],
rir_generator=rir_generator,
).to_dict()
)
)
return fastcopy(
self,
Expand All @@ -835,7 +845,7 @@ def resample(self, sampling_rate: int) -> "Recording":
Resample(
source_sampling_rate=self.sampling_rate,
target_sampling_rate=sampling_rate,
).to_dict()
)
)

new_num_samples = compute_num_samples(
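With the recording.py changes above, `Recording.transforms` can now hold live `AudioTransform` objects as well as their dict form; `to_dict()` converts them only when the manifest is serialized, and `load_audio()` accepts either representation. A minimal sketch of the round trip, assuming lhotse is installed and `audio/example.wav` is a placeholder path:

```python
# Sketch only: in-memory transforms vs. their serialized form.
from lhotse import Recording

rec = Recording.from_file("audio/example.wav")   # hypothetical path

# perturb_speed() now appends a Speed object (not a dict) to `transforms`.
sp = rec.perturb_speed(1.1)
print(type(sp.transforms[0]).__name__)           # Speed

# Serialization happens only in to_dict(), which calls each transform's to_dict().
d = sp.to_dict()
print(d["transforms"][0]["name"])                # "Speed"

# The dict form still round-trips, and load_audio() handles both representations.
sp2 = Recording.from_dict(d)
samples = sp2.load_audio()
```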
39 changes: 31 additions & 8 deletions lhotse/augmentation/rir.py
@@ -35,10 +35,14 @@ class ReverbWithImpulseResponse(AudioTransform):

def __post_init__(self):
if isinstance(self.rir, dict):
from lhotse import Recording
from lhotse.serialization import deserialize_item

# Pass a shallow copy of the RIR dict since `from_dict()` pops the `sources` key.
self.rir = Recording.from_dict(self.rir.copy())
# Pass a shallow copy of the RIR dict since deserialization is destructive
# If RIR is a Cut, we have to perform one extra copy (hacky but better than deepcopy).
rir = self.rir.copy()
if "recording" in self.rir:
rir["recording"] = rir["recording"].copy()
self.rir = deserialize_item(rir)

assert (
self.rir is not None or self.rir_generator is not None
@@ -52,6 +56,23 @@ def __post_init__(self):
if self.rir_generator is not None and isinstance(self.rir_generator, dict):
self.rir_generator = FastRandomRIRGenerator(**self.rir_generator)

def to_dict(self) -> dict:
from lhotse import Recording
from lhotse.cut import Cut

return {
"name": type(self).__name__,
"kwargs": {
"rir": self.rir.to_dict()
if isinstance(self.rir, (Recording, Cut))
else self.rir,
"normalize_output": self.normalize_output,
"early_only": self.early_only,
"rir_channels": list(self.rir_channels),
"rir_generator": self.rir_generator,
},
}

def __call__(
self,
samples: np.ndarray,
@@ -92,11 +113,13 @@ def __call__(
if self.rir is None:
rir_ = self.rir_generator(nsource=1)
else:
rir_ = (
self.rir.load_audio(channels=self.rir_channels)
if not self.early_only
else self.rir.load_audio(channels=self.rir_channels, duration=0.05)
)
from lhotse import Recording

rir = self.rir.to_cut() if isinstance(self.rir, Recording) else self.rir
rir = rir.with_channels(self.rir_channels)
if self.early_only:
rir = rir.truncate(duration=0.05)
rir_ = rir.load_audio()

D_rir, N_rir = rir_.shape
N_out = N_in # Enforce shift output
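The rir.py changes let `ReverbWithImpulseResponse` carry the impulse response as a `Recording`, a `Cut`, or an in-memory dict of either, and `__call__` now selects channels and the early part via `with_channels()`/`truncate()` rather than `load_audio(channels=..., duration=...)`. A rough usage sketch at the cut level, with hypothetical paths:

```python
# Sketch only: passing a Cut, rather than a Recording, as the RIR.
from lhotse import Recording

speech = Recording.from_file("audio/speech.wav").to_cut()   # hypothetical path
rir = Recording.from_file("rirs/room0.wav").to_cut()        # RIR wrapped in a Cut

# reverb_rir() on a cut now accepts a DataCut as well as a Recording;
# rir_channels picks which RIR channel(s) to convolve with.
wet = speech.reverb_rir(rir_recording=rir, rir_channels=[0])
samples = wet.load_audio()
```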
1 change: 0 additions & 1 deletion lhotse/cut/base.py
@@ -22,7 +22,6 @@
compute_start_duration_for_extended_cut,
fastcopy,
ifnone,
is_torchaudio_available,
overlaps,
to_hashable,
)
11 changes: 11 additions & 0 deletions lhotse/cut/data.py
@@ -22,6 +22,7 @@
Seconds,
TimeSpan,
add_durations,
asdict_nonull,
compute_num_frames,
compute_num_samples,
fastcopy,
@@ -70,6 +71,16 @@ class DataCut(Cut, CustomFieldMixin, metaclass=ABCMeta):
# Store anything else the user might want.
custom: Optional[Dict[str, Any]] = None

def to_dict(self) -> dict:
d = asdict_nonull(self)
if self.has_recording:
d["recording"] = self.recording.to_dict()
if self.custom is not None:
for k, v in self.custom.items():
if isinstance(v, Recording):
d["custom"][k] = v.to_dict()
return {**d, "type": type(self).__name__}

@property
def recording_id(self) -> str:
return self.recording.id if self.has_recording else self.features.recording_id
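The data.py change gives `DataCut` its own `to_dict()`, which serializes the attached `Recording` (and any `Recording` stored under `custom`) through `Recording.to_dict()`, so in-memory transforms survive manifest writing. A rough illustration, again with a hypothetical input path:

```python
# Sketch only: a cut whose recording carries an in-memory Volume transform.
from lhotse import Recording
from lhotse.serialization import deserialize_item

cut = Recording.from_file("audio/speech.wav").to_cut()    # hypothetical path
cut = cut.perturb_volume(2.0)                             # adds a Volume transform in memory

d = cut.to_dict()
print(d["type"])                                          # "MonoCut"
print(d["recording"]["transforms"][0]["name"])            # "Volume"

# deserialize_item() (also used in ReverbWithImpulseResponse.__post_init__)
# rebuilds the cut, including the nested recording dict.
cut2 = deserialize_item(d)
```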
52 changes: 49 additions & 3 deletions lhotse/cut/mono.py
@@ -3,7 +3,7 @@
from dataclasses import dataclass
from functools import partial, reduce
from operator import add
from typing import Any, Callable, Iterable, List, Optional, Tuple
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
@@ -16,6 +16,7 @@
add_durations,
fastcopy,
hash_str_to_int,
is_equal_or_contains,
merge_items_with_delimiter,
overlaps,
rich_exception_info,
@@ -102,13 +103,58 @@ def load_video(
)
return None

def with_channels(self, channels: Union[List[int], int]) -> DataCut:
"""
Select specified channels from this cut.
Supports extending to other channels available in the underlying :class:`Recording`.
If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`,
otherwise we'll return a :class:`~lhotse.cut.MultiCut`.
"""
channel_is_int = isinstance(channels, int)
assert set([channels] if channel_is_int else channels).issubset(
set(self.recording.channel_ids)
), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}"

mono = channel_is_int or len(channels) == 1
if mono:
if not channel_is_int:
(channels,) = channels
return MonoCut(
id=f"{self.id}-{channels}",
recording=self.recording,
start=self.start,
duration=self.duration,
channel=channels,
supervisions=[
fastcopy(s, channel=channels)
for s in self.supervisions
if is_equal_or_contains(s.channel, channels)
],
custom=self.custom,
)
else:
from lhotse import MultiCut

return MultiCut(
id=f"{self.id}-{len(channels)}chan",
start=self.start,
duration=self.duration,
channel=channels,
supervisions=[
s
for s in self.supervisions
if is_equal_or_contains(channels, s.channel)
],
custom=self.custom,
)

def reverb_rir(
self,
rir_recording: Optional["Recording"] = None,
rir_recording: Optional[Union[Recording, DataCut]] = None,
normalize_output: bool = True,
early_only: bool = False,
affix_id: bool = True,
rir_channels: List[int] = [0],
rir_channels: Sequence[int] = (0,),
room_rng_seed: Optional[int] = None,
source_rng_seed: Optional[int] = None,
) -> DataCut:
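The new `MonoCut.with_channels()` mirrors the existing `MultiCut` version: a single channel yields a `MonoCut`, several channels yield a `MultiCut`, and the requested channels may extend beyond the cut's own channel as long as the underlying `Recording` has them. This is what lets the RIR code pick `rir_channels` from a cut. A small sketch, assuming a hypothetical multi-channel RIR file:

```python
# Sketch only: channel selection/extension on a MonoCut.
from lhotse import MonoCut, Recording

rec = Recording.from_file("rirs/array_rir.wav")   # hypothetical multi-channel RIR
cut = MonoCut(id="rir-ch0", start=0.0, duration=rec.duration, channel=0, recording=rec)

mono = cut.with_channels(0)        # int or length-1 list -> MonoCut
multi = cut.with_channels([0, 1])  # multiple channels    -> MultiCut
```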
13 changes: 7 additions & 6 deletions lhotse/cut/multi.py
@@ -157,11 +157,11 @@ def load_video(

def reverb_rir(
self,
rir_recording: Optional["Recording"] = None,
rir_recording: Optional[Union[Recording, DataCut]] = None,
normalize_output: bool = True,
early_only: bool = False,
affix_id: bool = True,
rir_channels: List[int] = [0],
rir_channels: Sequence[int] = (0,),
room_rng_seed: Optional[int] = None,
source_rng_seed: Optional[int] = None,
) -> "MultiCut":
@@ -370,17 +370,18 @@ def with_channels(self, channels: Union[List[int], int]) -> DataCut:
Select specified channels from this cut.
Supports extending to other channels available in the underlying :class:`Recording`.
If a single channel is provided, we'll return a :class:`~lhotse.cut.MonoCut`,
otherwise we'll return a :class:`~lhotse.cut.MultiCut'.
otherwise we'll return a :class:`~lhotse.cut.MultiCut`.
"""
mono = isinstance(channels, int) or len(channels) == 1
assert set([channels] if mono else channels).issubset(
channel_is_int = isinstance(channels, int)
assert set([channels] if channel_is_int else channels).issubset(
set(self.recording.channel_ids)
), f"Cannot select {channels=} because they are not a subset of {self.recording.channel_ids=}"

mono = channel_is_int or len(channels) == 1
if mono:
from .mono import MonoCut

if isinstance(channels, Sequence):
if not channel_is_int:
(channels,) = channels
return MonoCut(
id=f"{self.id}-{channels}",