Skip to content

Commit

Permalink
Fix Recording.to_dict() when transforms are dicts and transform pickling issues (#1355)
Browse files Browse the repository at this point in the history

* Fix Recording.to_dict() when transforms are dicts

* fix

* fix pickling issues with transforms

* fix

* fix

* fix
  • Loading branch information
pzelasko committed Jun 24, 2024
1 parent 9930ae4 commit d1f94c0
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 7 deletions.
13 changes: 11 additions & 2 deletions lhotse/audio/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,9 @@ def _aslist(x):
def to_dict(self) -> dict:
d = asdict_nonull(self)
if self.transforms is not None:
d["transforms"] = [t.to_dict() for t in self.transforms]
d["transforms"] = [
t if isinstance(t, dict) else t.to_dict() for t in self.transforms
]
return d

def to_cut(self):
Expand Down Expand Up @@ -866,8 +868,15 @@ def resample(self, sampling_rate: int) -> "Recording":
@staticmethod
def from_dict(data: dict) -> "Recording":
raw_sources = data.pop("sources")
try:
transforms = data.pop("transforms")
transforms = [AudioTransform.from_dict(t) for t in transforms]
except KeyError:
transforms = None
return Recording(
sources=[AudioSource.from_dict(s) for s in raw_sources], **data
sources=[AudioSource.from_dict(s) for s in raw_sources],
transforms=transforms,
**data,
)


Expand Down
6 changes: 4 additions & 2 deletions lhotse/augmentation/rir.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -69,7 +69,9 @@ def to_dict(self) -> dict:
"normalize_output": self.normalize_output,
"early_only": self.early_only,
"rir_channels": list(self.rir_channels),
"rir_generator": self.rir_generator,
"rir_generator": self.rir_generator
if self.rir_generator is None or isinstance(self.rir_generator, dict)
else self.rir_generator.to_dict(),
},
}

Expand Down
5 changes: 4 additions & 1 deletion lhotse/augmentation/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from typing import List, Optional

import numpy as np
Expand Down Expand Up @@ -100,6 +100,9 @@ def __post_init__(self):
else np.random.default_rng()
)

def to_dict(self):
return asdict(self)

def __call__(self, nsource: int = 1) -> np.ndarray:
"""
:param nsource: number of sources (RIR filters) to simulate. Default: 1.
Expand Down
2 changes: 1 addition & 1 deletion lhotse/dataset/cut_transforms/perturb_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(

def __call__(self, cuts: CutSet) -> CutSet:
if self.random is None:
self.random = random
self.random = random.Random()
return CutSet.from_cuts(
cut.perturb_speed(
factor=self.random.choice(self.factors), affix_id=not self.preserve_id
Expand Down
2 changes: 1 addition & 1 deletion lhotse/dataset/cut_transforms/reverberate.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(

def __call__(self, cuts: CutSet) -> CutSet:
if self.random is None:
self.random = random
self.random = random.Random()
return CutSet.from_cuts(
cut.reverb_rir(
rir_recording=self.random.choice(self.rir_recordings)
Expand Down
10 changes: 10 additions & 0 deletions test/audio/test_recording_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from lhotse.audio import DurationMismatchError
from lhotse.audio.mixer import AudioMixer
from lhotse.augmentation import ReverbWithImpulseResponse
from lhotse.testing.dummies import DummyManifest
from lhotse.utils import INT16MAX, fastcopy, is_module_available
from lhotse.utils import nullcontext as does_not_raise
Expand Down Expand Up @@ -632,3 +633,12 @@ def test_memory_recording_dict_serialization():
rec_reconstructed = Recording.from_dict(data)
assert rec == rec_reconstructed
np.testing.assert_equal(rec_reconstructed.load_audio(), rec.load_audio())


def test_recording_to_dict_with_transform_dict():
path = "test/fixtures/mono_c0.wav"
recording = Recording.from_file(path)
recording = recording.reverb_rir()
serialized = recording.to_dict()
recording_restored = Recording.from_dict(serialized)
assert recording == recording_restored

0 comments on commit d1f94c0

Please sign in to comment.