From 455b20e41ad8da21e16861783f7a2b2342d54c27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 31 Jan 2024 14:18:58 -0500 Subject: [PATCH] Allow duplicate cut IDs in a CutSet (CutSet is list-like instead of dict-like) (#1279) * Allow duplicate cut IDs in a CutSet (CutSet is list-like instead of dict-like) * Remove BaseIterable altogether (was renamed from ImitatesDict and is no longer needed) * cleanup duplicate checking fn --- lhotse/audio/recording_set.py | 50 ++++----- lhotse/cut/set.py | 103 ++++++++++-------- lhotse/lazy.py | 52 +++------ lhotse/shar/readers/lazy.py | 10 +- lhotse/supervision.py | 44 ++++---- lhotse/utils.py | 8 -- .../meeting_simulation/conversational.py | 2 +- .../meeting_simulation/speaker_independent.py | 2 +- test/cut/test_cut.py | 28 ++--- test/cut/test_cut_set.py | 13 ++- test/test_lazy.py | 6 +- 11 files changed, 150 insertions(+), 168 deletions(-) diff --git a/lhotse/audio/recording_set.py b/lhotse/audio/recording_set.py index bdb9cc25c..e26ab8c52 100644 --- a/lhotse/audio/recording_set.py +++ b/lhotse/audio/recording_set.py @@ -18,7 +18,6 @@ Seconds, exactly_one_not_null, ifnone, - index_by_id_and_check, split_manifest_lazy, split_sequence, ) @@ -26,7 +25,7 @@ class RecordingSet(Serializable, AlgorithmMixin): """ - :class:`~lhotse.audio.RecordingSet` represents a collection of recordings, indexed by recording IDs. + :class:`~lhotse.audio.RecordingSet` represents a collection of recordings. It does not contain any annotation such as the transcript or the speaker identity -- just the information needed to retrieve a recording such as its path, URL, number of channels, and some recording metadata (duration, number of samples). @@ -86,7 +85,7 @@ class RecordingSet(Serializable, AlgorithmMixin): >>> recs_24k = recs.resample(24000) """ - def __init__(self, recordings: Optional[Mapping[str, Recording]] = None) -> None: + def __init__(self, recordings: Optional[Iterable[Recording]] = None) -> None: self.recordings = ifnone(recordings, {}) def __eq__(self, other: "RecordingSet") -> bool: @@ -99,11 +98,11 @@ def data(self) -> Union[Dict[str, Recording], Iterable[Recording]]: @property def ids(self) -> Iterable[str]: - return self.recordings.keys() + return (r.id for r in self) @staticmethod def from_recordings(recordings: Iterable[Recording]) -> "RecordingSet": - return RecordingSet(recordings=index_by_id_and_check(recordings)) + return RecordingSet(list(recordings)) from_items = from_recordings @@ -254,7 +253,7 @@ def load_audio( offset_seconds: float = 0.0, duration_seconds: Optional[float] = None, ) -> np.ndarray: - return self.recordings[recording_id].load_audio( + return self[recording_id].load_audio( channels=channels, offset=offset_seconds, duration=duration_seconds ) @@ -262,16 +261,16 @@ def with_path_prefix(self, path: Pathlike) -> "RecordingSet": return RecordingSet.from_recordings(r.with_path_prefix(path) for r in self) def num_channels(self, recording_id: str) -> int: - return self.recordings[recording_id].num_channels + return self[recording_id].num_channels def sampling_rate(self, recording_id: str) -> int: - return self.recordings[recording_id].sampling_rate + return self[recording_id].sampling_rate def num_samples(self, recording_id: str) -> int: - return self.recordings[recording_id].num_samples + return self[recording_id].num_samples def duration(self, recording_id: str) -> Seconds: - return self.recordings[recording_id].duration + return self[recording_id].duration def perturb_speed(self, factor: float, affix_id: bool = 
True) -> "RecordingSet": """ @@ -368,24 +367,25 @@ def resample(self, sampling_rate: int) -> "RecordingSet": def __repr__(self) -> str: return f"RecordingSet(len={len(self)})" - def __contains__(self, item: Union[str, Recording]) -> bool: - if isinstance(item, str): - return item in self.recordings + def __getitem__(self, index_or_id: Union[int, str]) -> Recording: + try: + return self.recordings[index_or_id] # int passed, eager manifest, fast + except TypeError: + # either lazy manifest or str passed, both are slow + if self.is_lazy: + return next(item for idx, item in enumerate(self) if idx == index_or_id) + else: + # string id passed, support just for backward compatibility, not recommended + return next(item for item in self if item.id == index_or_id) + + def __contains__(self, other: Union[str, Recording]) -> bool: + if isinstance(other, str): + return any(other == item.id for item in self) else: - return item.id in self.recordings - - def __getitem__(self, recording_id_or_index: Union[int, str]) -> Recording: - if isinstance(recording_id_or_index, str): - return self.recordings[recording_id_or_index] - # ~100x faster than list(dict.values())[index] for 100k elements - return next( - val - for idx, val in enumerate(self.recordings.values()) - if idx == recording_id_or_index - ) + return any(other.id == item.id for item in self) def __iter__(self) -> Iterable[Recording]: - return iter(self.recordings.values()) + yield from self.recordings def __len__(self) -> int: return len(self.recordings) diff --git a/lhotse/cut/set.py b/lhotse/cut/set.py index 737376f0f..440bf09f2 100644 --- a/lhotse/cut/set.py +++ b/lhotse/cut/set.py @@ -17,7 +17,6 @@ Iterable, List, Literal, - Mapping, Optional, Sequence, Set, @@ -45,7 +44,7 @@ from lhotse.features.io import FeaturesWriter, LilcomChunkyWriter from lhotse.lazy import ( AlgorithmMixin, - ImitatesDict, + Dillable, LazyFlattener, LazyIteratorChain, LazyManifestIterator, @@ -62,10 +61,10 @@ Seconds, compute_num_frames, compute_num_samples, + deprecated, exactly_one_not_null, fastcopy, ifnone, - index_by_id_and_check, split_manifest_lazy, split_sequence, uuid4, @@ -76,10 +75,15 @@ class CutSet(Serializable, AlgorithmMixin): """ - :class:`~lhotse.cut.CutSet` represents a collection of cuts, indexed by cut IDs. + :class:`~lhotse.cut.CutSet` represents a collection of cuts. CutSet ties together all types of data -- audio, features and supervisions, and is suitable to represent training/dev/test sets. + CutSet can be either "lazy" (acts as an iterable) which is best for representing full datasets, + or "eager" (acts as a list), which is best for representing individual mini-batches (and sometimes test/dev datasets). + Almost all operations are available for both modes, but some of them are more efficient depending on the mode + (e.g. indexing an "eager" manifest is O(1)). + .. note:: :class:`~lhotse.cut.CutSet` is the basic building block of PyTorch-style Datasets for speech/audio processing tasks. 
@@ -242,34 +246,32 @@ class CutSet(Serializable, AlgorithmMixin): - :class:`~lhotse.cut.Cut` """ - def __init__( - self, cuts: Optional[Union[Mapping[str, Cut], ImitatesDict]] = None - ) -> None: - self.cuts = ifnone(cuts, {}) + def __init__(self, cuts: Optional[Iterable[Cut]] = None) -> None: + self.cuts = ifnone(cuts, []) def __eq__(self, other: "CutSet") -> bool: return self.cuts == other.cuts @property - def data(self) -> Union[Dict[str, Cut], Iterable[Cut]]: + def data(self) -> Iterable[Cut]: """Alias property for ``self.cuts``""" return self.cuts @property - def mixed_cuts(self) -> Dict[str, MixedCut]: - return {id_: cut for id_, cut in self.cuts.items() if isinstance(cut, MixedCut)} + def mixed_cuts(self) -> "CutSet": + return CutSet.from_cuts(cut for cut in self.cuts if isinstance(cut, MixedCut)) @property - def simple_cuts(self) -> Dict[str, MonoCut]: - return {id_: cut for id_, cut in self.cuts.items() if isinstance(cut, MonoCut)} + def simple_cuts(self) -> "CutSet": + return CutSet.from_cuts(cut for cut in self.cuts if isinstance(cut, MonoCut)) @property - def multi_cuts(self) -> Dict[str, MultiCut]: - return {id_: cut for id_, cut in self.cuts.items() if isinstance(cut, MultiCut)} + def multi_cuts(self) -> "CutSet": + return CutSet.from_cuts(cut for cut in self.cuts if isinstance(cut, MultiCut)) @property def ids(self) -> Iterable[str]: - return self.cuts.keys() + return (c.id for c in self.cuts) @property def speakers(self) -> FrozenSet[str]: @@ -307,7 +309,8 @@ def from_files( @staticmethod def from_cuts(cuts: Iterable[Cut]) -> "CutSet": - return CutSet(cuts=index_by_id_and_check(cuts)) + """Left for backward compatibility, where it implicitly created an "eager" CutSet.""" + return CutSet(list(cuts)) from_items = from_cuts @@ -827,7 +830,7 @@ def split( :return: A list of :class:`~lhotse.CutSet` pieces. """ return [ - CutSet.from_cuts(subset) + CutSet(subset) for subset in split_sequence( self, num_splits=num_splits, @@ -925,14 +928,14 @@ def subset( cut_ids = list(cut_ids) # Remember the original order id_set = frozenset(cut_ids) # Make a set for quick lookup # Iteration makes it possible to subset lazy manifests - cuts = CutSet.from_cuts(cut for cut in self if cut.id in id_set) + cuts = CutSet([cut for cut in self if cut.id in id_set]) if len(cuts) < len(cut_ids): logging.warning( f"In CutSet.subset(cut_ids=...): expected {len(cut_ids)} cuts but got {len(cuts)} " f"instead ({len(cut_ids) - len(cuts)} cut IDs were not in the CutSet)." ) # Restore the requested cut_ids order. - return CutSet.from_cuts(cuts[cid] for cid in cut_ids) + return cuts.sort_like(cut_ids) def filter_supervisions( self, predicate: Callable[[SupervisionSegment], bool] @@ -1142,7 +1145,7 @@ def trim_to_unsupervised_segments(self) -> "CutSet": ) for span in segments: cuts.append(cut.truncate(offset=span.start, duration=span.duration)) - return CutSet.from_cuts(cuts) + return CutSet(cuts) def trim_to_supervision_groups( self, @@ -1245,7 +1248,7 @@ def sort_by_recording_id(self, ascending: bool = True) -> "CutSet": This is advantageous before calling `save_audios()` on a `trim_to_supervision()` processed `CutSet`; also make sure that `set_caching_enabled(True)` was called. """ - return CutSet.from_cuts( + return CutSet( sorted(self, key=(lambda cut: cut.recording.id), reverse=not ascending) ) @@ -1253,18 +1256,23 @@ def sort_by_duration(self, ascending: bool = False) -> "CutSet": """ Sort the CutSet according to cut duration and return the result. Descending by default.
""" - return CutSet.from_cuts( + return CutSet( sorted(self, key=(lambda cut: cut.duration), reverse=not ascending) ) - def sort_like(self, other: "CutSet") -> "CutSet": + def sort_like(self, other: Union["CutSet", Sequence[str]]) -> "CutSet": """ Sort the CutSet according to the order of cut IDs in ``other`` and return the result. """ + other_ids = list(other.ids if isinstance(other, CutSet) else other) assert set(self.ids) == set( - other.ids + other_ids ), "sort_like() expects both CutSet's to have identical cut IDs." - return CutSet.from_cuts(self[cid] for cid in other.ids) + index_map: Dict[str, int] = {v: index for index, v in enumerate(other_ids)} + ans: List[Cut] = [None] * len(other_ids) + for cut in self: + ans[index_map[cut.id]] = cut + return CutSet(ans) def index_supervisions( self, index_mixed_tracks: bool = False, keep_ids: Optional[Set[str]] = None @@ -1397,7 +1405,7 @@ def compute_offset(): preserve_id=preserve_id, ) ) - return CutSet.from_cuts(truncated_cuts) + return CutSet(truncated_cuts) def extend_by( self, @@ -1513,13 +1521,11 @@ def sample(self, n_cuts: int = 1) -> Union[Cut, "CutSet"]: When ``n_cuts`` is 1, will return a single cut instance; otherwise will return a ``CutSet``. """ assert n_cuts > 0 - # TODO: We might want to make this more efficient in the future - # by holding a cached list of cut ids as a member of CutSet... cut_indices = random.sample(range(len(self)), min(n_cuts, len(self))) cuts = [self[idx] for idx in cut_indices] if n_cuts == 1: return cuts[0] - return CutSet.from_cuts(cuts) + return CutSet(cuts) def resample(self, sampling_rate: int, affix_id: bool = False) -> "CutSet": """ @@ -2194,7 +2200,7 @@ def file_storage_path(cut: Cut, storage_path: Pathlike) -> Path: progress = partial( tqdm, desc="Storing audio recordings", total=len(self) ) - return CutSet.from_cuts( + return CutSet( progress( cut.save_audio( storage_path=file_storage_path(cut, storage_path), @@ -2204,7 +2210,7 @@ def file_storage_path(cut: Cut, storage_path: Pathlike) -> Path: ) for cut in self ) - ) + ).to_eager() # Parallel execution: prepare the CutSet splits cut_sets = self.split(num_jobs, shuffle=shuffle_on_split) @@ -2495,25 +2501,28 @@ def __repr__(self) -> str: len_val = "" return f"CutSet(len={len_val}) [underlying data type: {type(self.data)}]" - def __contains__(self, item: Union[str, Cut]) -> bool: - if isinstance(item, str): - return item in self.cuts + def __contains__(self, other: Union[str, Cut]) -> bool: + if isinstance(other, str): + return any(other == item.id for item in self) else: - return item.id in self.cuts - - def __getitem__(self, cut_id_or_index: Union[int, str]) -> "Cut": - if isinstance(cut_id_or_index, str): - return self.cuts[cut_id_or_index] - # ~100x faster than list(dict.values())[index] for 100k elements - return next( - val for idx, val in enumerate(self.cuts.values()) if idx == cut_id_or_index - ) + return any(other.id == item.id for item in self) + + def __getitem__(self, index_or_id: Union[int, str]) -> Cut: + try: + return self.cuts[index_or_id] # int passed, eager manifest, fast + except TypeError: + # either lazy manifest or str passed, both are slow + if self.is_lazy: + return next(item for idx, item in enumerate(self) if idx == index_or_id) + else: + # string id passed, support just for backward compatibility, not recommended + return next(item for item in self if item.id == index_or_id) def __len__(self) -> int: return len(self.cuts) def __iter__(self) -> Iterable[Cut]: - return iter(self.cuts.values()) + yield from self.cuts def 
mix( @@ -2993,7 +3002,7 @@ def create_cut_set_eager( else [], ) ) - cuts = CutSet.from_cuts(cuts) + cuts = CutSet(cuts) if output_path is not None: cuts.to_file(output_path) return cuts @@ -3391,7 +3400,7 @@ def _export_to_shar_single( return writer.output_paths -class LazyCutMixer(ImitatesDict): +class LazyCutMixer(Dillable): """ Iterate over cuts from ``cuts`` CutSet while mixing randomly sampled ``mix_in_cuts`` into them. A typical application would be data augmentation with noise, music, babble, etc. diff --git a/lhotse/lazy.py b/lhotse/lazy.py index 3f56a1d67..8a6a470f2 100644 --- a/lhotse/lazy.py +++ b/lhotse/lazy.py @@ -155,9 +155,9 @@ def shuffle( if self.is_lazy: return cls(LazyShuffler(self.data, buffer_size=buffer_size, rng=rng)) else: - ids = list(self.ids) - rng.shuffle(ids) - return cls({id_: self[id_] for id_ in ids}) + new: List = self.data.copy() + rng.shuffle(new) + return cls(new) def repeat(self, times: Optional[int] = None, preserve_id: bool = False): """ @@ -242,26 +242,6 @@ def dill_enabled(value: bool): set_dill_enabled(previous) -class ImitatesDict(Dillable): - """ - Helper base class for lazy iterators defined below. - It exists to make them drop-in replacements for data-holding dicts - in Lhotse's CutSet, RecordingSet, etc. classes. - """ - - def __iter__(self): - raise NotImplemented - - def values(self): - yield from self - - def keys(self): - return (item.id for item in self) - - def items(self): - return ((item.id, item) for item in self) - - class LazyJsonlIterator: """ LazyJsonlIterator provides the ability to read JSON lines as Python dicts. @@ -288,7 +268,7 @@ def __len__(self) -> int: return self._len -class LazyManifestIterator(ImitatesDict): +class LazyManifestIterator(Dillable): """ LazyManifestIterator provides the ability to read Lhotse objects from a JSONL file on-the-fly, without reading its full contents into memory. @@ -317,7 +297,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) -class LazyIteratorChain(ImitatesDict): +class LazyIteratorChain(Dillable): """ A thin wrapper over multiple iterators that enables to combine lazy manifests in Lhotse. It iterates all underlying iterables sequentially. @@ -370,7 +350,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) -class LazyIteratorMultiplexer(ImitatesDict): +class LazyIteratorMultiplexer(Dillable): """ A wrapper over multiple iterators that enables to combine lazy manifests in Lhotse. During iteration, unlike :class:`.LazyIteratorChain`, :class:`.LazyIteratorMultiplexer` @@ -439,7 +419,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) -class LazyInfiniteApproximateMultiplexer(ImitatesDict): +class LazyInfiniteApproximateMultiplexer(Dillable): """ A variant of :class:`.LazyIteratorMultiplexer` that allows to control the number of iterables that are simultaneously open. @@ -560,7 +540,7 @@ def sample_new_stream_at(pos: int) -> None: yield item -class LazyShuffler(ImitatesDict): +class LazyShuffler(Dillable): """ A wrapper over an iterable that enables lazy shuffling. The shuffling algorithm is reservoir-sampling based. @@ -593,7 +573,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) -class LazyFilter(ImitatesDict): +class LazyFilter(Dillable): """ A wrapper over an iterable that enables lazy filtering. 
It works like Python's `filter` built-in by applying the filter predicate @@ -624,7 +604,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def __len__(self) -> int: - raise NotImplementedError( + raise TypeError( "LazyFilter does not support __len__ because it would require " "iterating over the whole iterator, which is not possible in a lazy fashion. " "If you really need to know the length, convert to eager mode first using " @@ -632,7 +612,7 @@ def __len__(self) -> int: ) -class LazyMapper(ImitatesDict): +class LazyMapper(Dillable): """ A wrapper over an iterable that enables lazy function evaluation on each item. It works like Python's `map` built-in by applying a callable ``fn`` @@ -664,7 +644,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) -class LazyFlattener(ImitatesDict): +class LazyFlattener(Dillable): """ A wrapper over an iterable of collections that flattens it to an iterable of items. @@ -685,7 +665,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def __len__(self) -> int: - raise NotImplementedError( + raise TypeError( "LazyFlattener does not support __len__ because it would require " "iterating over the whole iterator, which is not possible in a lazy fashion. " "If you really need to know the length, convert to eager mode first using " @@ -693,7 +673,7 @@ def __len__(self) -> int: ) -class LazyRepeater(ImitatesDict): +class LazyRepeater(Dillable): """ A wrapper over an iterable that enables to repeat it N times or infinitely (default). """ @@ -728,7 +708,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) -class LazySlicer(ImitatesDict): +class LazySlicer(Dillable): """ A wrapper over an iterable that enables selecting k-th element every n elements. """ @@ -750,7 +730,7 @@ def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def __len__(self) -> int: - raise NotImplementedError( + raise TypeError( "LazySlicer does not support __len__ because it would require " "iterating over the whole iterator, which is not possible in a lazy fashion. " "If you really need to know the length, convert to eager mode first using " diff --git a/lhotse/shar/readers/lazy.py b/lhotse/shar/readers/lazy.py index 306aaa675..c1b0b74a3 100644 --- a/lhotse/shar/readers/lazy.py +++ b/lhotse/shar/readers/lazy.py @@ -1,6 +1,4 @@ -import os import random -import secrets from pathlib import Path from typing import ( Callable, @@ -14,12 +12,10 @@ Union, ) -import torch - from lhotse.cut import Cut -from lhotse.dataset.dataloading import LHOTSE_PROCESS_SEED, resolve_seed +from lhotse.dataset.dataloading import resolve_seed from lhotse.lazy import ( - ImitatesDict, + Dillable, LazyIteratorChain, LazyJsonlIterator, LazyManifestIterator, @@ -30,7 +26,7 @@ from lhotse.utils import Pathlike, exactly_one_not_null, ifnone -class LazySharIterator(ImitatesDict): +class LazySharIterator(Dillable): """ LazySharIterator reads cuts and their corresponding data from multiple shards, also recognized as the Lhotse Shar format. 
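For illustration, a sketch of how the lazy views (now plain Dillable iterables) behave after this change; the manifest path is hypothetical, while `from_jsonl_lazy`, `filter`, and `to_eager` are existing Lhotse APIs:

    from lhotse import CutSet

    cuts = CutSet.from_jsonl_lazy("cuts.jsonl.gz")    # hypothetical path; opens a lazy CutSet
    longer = cuts.filter(lambda c: c.duration > 1.0)  # wraps LazyFilter; length unknown without iterating
    try:
        len(longer)
    except TypeError:  # raised instead of NotImplementedError after this patch
        pass
    eager = longer.to_eager()  # materialize; len() and O(1) indexing work again

Raising TypeError matches what Python itself raises for objects without `__len__`, so callers can handle any unsized iterable with a single except clause.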
diff --git a/lhotse/supervision.py b/lhotse/supervision.py index a7723cf82..a905cfcab 100644 --- a/lhotse/supervision.py +++ b/lhotse/supervision.py @@ -28,7 +28,6 @@ exactly_one_not_null, fastcopy, ifnone, - index_by_id_and_check, is_equal_or_contains, overspans, perturb_num_samples, @@ -505,12 +504,12 @@ def __delattr__(self, key: str) -> None: class SupervisionSet(Serializable, AlgorithmMixin): """ :class:`~lhotse.supervision.SupervisionSet` represents a collection of segments containing some - supervision information (see :class:`~lhotse.supervision.SupervisionSegment`), - that are indexed by segment IDs. + supervision information (see :class:`~lhotse.supervision.SupervisionSegment`). - It acts as a Python ``dict``, extended with an efficient ``find`` operation that indexes and caches + It acts as a Python ``list``, extended with an efficient ``find`` operation that indexes and caches the supervision segments in an interval tree. It makes it possible to quickly find supervision segments that correspond to a specific time interval. + However, it can also work with lazy iterables. When coming from Kaldi, think of :class:`~lhotse.supervision.SupervisionSet` as a ``segments`` file on steroids, which may also contain *text*, *utt2spk*, *utt2gender*, *utt2dur*, etc. @@ -548,9 +547,7 @@ class SupervisionSet(Serializable, AlgorithmMixin): >>> shuffled = sups.shuffle() """ - def __init__( - self, segments: Optional[Mapping[str, SupervisionSegment]] = None - ) -> None: + def __init__(self, segments: Optional[Iterable[SupervisionSegment]] = None) -> None: self.segments = ifnone(segments, {}) def __eq__(self, other: "SupervisionSet") -> bool: @@ -565,11 +562,11 @@ def data( @property def ids(self) -> Iterable[str]: - return self.segments.keys() + return (s.id for s in self) @staticmethod def from_segments(segments: Iterable[SupervisionSegment]) -> "SupervisionSet": - return SupervisionSet(segments=index_by_id_and_check(segments)) + return SupervisionSet(list(segments)) from_items = from_segments @@ -893,24 +890,25 @@ def _index_by_recording_id_and_cache(self): def __repr__(self) -> str: return f"SupervisionSet(len={len(self)})" - def __getitem__(self, sup_id_or_index: Union[int, str]) -> SupervisionSegment: - if isinstance(sup_id_or_index, str): - return self.segments[sup_id_or_index] - # ~100x faster than list(dict.values())[index] for 100k elements - return next( - val - for idx, val in enumerate(self.segments.values()) - if idx == sup_id_or_index - ) + def __getitem__(self, index_or_id: Union[int, str]) -> SupervisionSegment: + try: + return self.segments[index_or_id] # int passed, eager manifest, fast + except TypeError: + # either lazy manifest or str passed, both are slow + if self.is_lazy: + return next(item for idx, item in enumerate(self) if idx == index_or_id) + else: + # string id passed, support just for backward compatibility, not recommended + return next(item for item in self if item.id == index_or_id) - def __contains__(self, item: Union[str, SupervisionSegment]) -> bool: - if isinstance(item, str): - return item in self.segments + def __contains__(self, other: Union[str, SupervisionSegment]) -> bool: + if isinstance(other, str): + return any(other == item.id for item in self) else: - return item.id in self.segments + return any(other.id == item.id for item in self) def __iter__(self) -> Iterable[SupervisionSegment]: - return iter(self.segments.values()) + yield from self.segments def __len__(self) -> int: return len(self.segments) diff --git a/lhotse/utils.py b/lhotse/utils.py index
3aa993cd8..f303100d3 100644 --- a/lhotse/utils.py +++ b/lhotse/utils.py @@ -704,14 +704,6 @@ def merge_items_with_delimiter( return delimiter.join(chain([prefix], values)) -def index_by_id_and_check(manifests: Iterable[T]) -> Dict[str, T]: - id2man = {} - for m in manifests: - assert m.id not in id2man, f"Duplicated manifest ID: {m.id}" - id2man[m.id] = m - return id2man - - def exactly_one_not_null(*args) -> bool: not_null = [arg is not None for arg in args] return sum(not_null) == 1 diff --git a/lhotse/workflows/meeting_simulation/conversational.py b/lhotse/workflows/meeting_simulation/conversational.py index 899f15e5d..56fe1c3c1 100644 --- a/lhotse/workflows/meeting_simulation/conversational.py +++ b/lhotse/workflows/meeting_simulation/conversational.py @@ -179,7 +179,7 @@ def _create_mixture( ] diff_spk_bernoulli = self.bernoulli.rvs(p=self.prob_diff_spk_overlap, size=N) - utterances = list(utterances.data.values()) + utterances = list(utterances) # First sample offsets for each utterance. These are w.r.t. start of the meeting. # For each subsequent utterance, we sample a pause or overlap time from the # corresponding distribution. Then, we add the pause/overlap time to the offset diff --git a/lhotse/workflows/meeting_simulation/speaker_independent.py b/lhotse/workflows/meeting_simulation/speaker_independent.py index 83ecd784a..edb0ac97a 100644 --- a/lhotse/workflows/meeting_simulation/speaker_independent.py +++ b/lhotse/workflows/meeting_simulation/speaker_independent.py @@ -104,7 +104,7 @@ def _create_mixture( zip(utterances, silence_durations) ): # Get list of cuts from CutSet - spk_utterances = list(spk_utterances.data.values()) + spk_utterances = list(spk_utterances) track = spk_utterances[0] for sil, utt in zip(spk_silences[1:], spk_utterances[1:]): track = mix(track, utt, offset=track.duration + sil, allow_padding=True) diff --git a/test/cut/test_cut.py b/test/cut/test_cut.py index e177eba89..6a17f1f4c 100644 --- a/test/cut/test_cut.py +++ b/test/cut/test_cut.py @@ -38,14 +38,14 @@ def libri_cut(libri_cut_set) -> MonoCut: def test_load_none_feats_cut_set(): cutset = CutSet.from_json("test/fixtures/libri/cuts_no_feats.json") - cut = list(cutset.cuts.values())[0] + cut = cutset[0] assert cut.features is None assert cut.recording is not None def test_load_none_recording_cut_set(): cutset = CutSet.from_json("test/fixtures/libri/cuts_no_recording.json") - cut = list(cutset.cuts.values())[0] + cut = cutset[0] assert cut.recording is None assert cut.features is not None @@ -185,7 +185,7 @@ def test_make_cuts_from_recordings(dummy_recording_set): assert len(cut1.supervisions) == 0 assert cut1.has_recording - assert cut1.recording == dummy_recording_set.recordings["rec1"] + assert cut1.recording == dummy_recording_set["rec1"] assert cut1.sampling_rate == 16000 assert cut1.recording_id == "rec1" assert cut1.num_samples == 160000 @@ -215,7 +215,7 @@ def test_make_cuts_from_features(dummy_feature_set): assert cut1.num_samples is None assert cut1.has_features - assert cut1.features == dummy_feature_set.features[0] + assert cut1.features == dummy_feature_set[0] assert cut1.frame_shift == 0.01 assert cut1.num_frames == 1000 assert cut1.num_features == 23 @@ -235,13 +235,13 @@ def test_make_cuts_from_features_recordings(dummy_recording_set, dummy_feature_s assert len(cut1.supervisions) == 0 assert cut1.has_recording - assert cut1.recording == dummy_recording_set.recordings["rec1"] + assert cut1.recording == dummy_recording_set["rec1"] assert cut1.sampling_rate == 16000 assert cut1.recording_id 
== "rec1" assert cut1.num_samples == 160000 assert cut1.has_features - assert cut1.features == dummy_feature_set.features[0] + assert cut1.features == dummy_feature_set[0] assert cut1.frame_shift == 0.01 assert cut1.num_frames == 1000 assert cut1.num_features == 23 @@ -295,7 +295,7 @@ def test_make_cuts_from_recordings_supervisions( assert cut1.supervisions[0].text == "dummy text" assert cut1.has_recording - assert cut1.recording == dummy_recording_set.recordings["rec1"] + assert cut1.recording == dummy_recording_set["rec1"] assert cut1.sampling_rate == 16000 assert cut1.recording_id == "rec1" assert cut1.num_samples == 16000 * 4 @@ -335,7 +335,7 @@ def test_make_cuts_from_features_supervisions( assert cut1.num_samples is None assert cut1.has_features - assert cut1.features == dummy_feature_set.features[0] + assert cut1.features == dummy_feature_set[0] assert cut1.frame_shift == 0.01 assert cut1.num_frames == 400 assert cut1.num_features == 23 @@ -365,13 +365,13 @@ def test_make_cuts_from_recordings_features_supervisions( assert cut1.supervisions[0].text == "dummy text" assert cut1.has_recording - assert cut1.recording == dummy_recording_set.recordings["rec1"] + assert cut1.recording == dummy_recording_set["rec1"] assert cut1.sampling_rate == 16000 assert cut1.recording_id == "rec1" assert cut1.num_samples == 16000 * 4 assert cut1.has_features - assert cut1.features == dummy_feature_set.features[0] + assert cut1.features == dummy_feature_set[0] assert cut1.frame_shift == 0.01 assert cut1.num_frames == 400 assert cut1.num_features == 23 @@ -400,7 +400,7 @@ def test_make_cuts_from_recordings_supervisions( assert cut1.supervisions[0].text == "dummy text" assert cut1.has_recording - assert cut1.recording == dummy_recording_set.recordings["rec1"] + assert cut1.recording == dummy_recording_set["rec1"] assert cut1.sampling_rate == 16000 assert cut1.recording_id == "rec1" assert cut1.num_samples == 160000 @@ -439,7 +439,7 @@ def test_make_cuts_from_features_supervisions( assert cut1.num_samples is None assert cut1.has_features - assert cut1.features == dummy_feature_set.features[0] + assert cut1.features == dummy_feature_set[0] assert cut1.frame_shift == 0.01 assert cut1.num_frames == 1000 assert cut1.num_features == 23 @@ -468,13 +468,13 @@ def test_make_cuts_from_recordings_features_supervisions( assert cut1.supervisions[0].text == "dummy text" assert cut1.has_recording - assert cut1.recording == dummy_recording_set.recordings["rec1"] + assert cut1.recording == dummy_recording_set["rec1"] assert cut1.sampling_rate == 16000 assert cut1.recording_id == "rec1" assert cut1.num_samples == 160000 assert cut1.has_features - assert cut1.features == dummy_feature_set.features[0] + assert cut1.features == dummy_feature_set[0] assert cut1.frame_shift == 0.01 assert cut1.num_frames == 1000 assert cut1.num_features == 23 diff --git a/test/cut/test_cut_set.py b/test/cut/test_cut_set.py index 395848055..209978d6e 100644 --- a/test/cut/test_cut_set.py +++ b/test/cut/test_cut_set.py @@ -48,7 +48,7 @@ def cut_set_with_mixed_cut(cut1, cut2): id="mixed-cut-id", tracks=[MixTrack(cut=cut1), MixTrack(cut=cut2, offset=1.0, snr=10)], ) - return CutSet({cut.id: cut for cut in [cut1, cut2, mixed_cut]}) + return CutSet([cut1, cut2, mixed_cut]) @pytest.mark.parametrize( @@ -82,10 +82,10 @@ def test_cut_set_iteration(cut_set_with_mixed_cut): def test_cut_set_holds_both_simple_and_mixed_cuts(cut_set_with_mixed_cut): - simple_cuts = cut_set_with_mixed_cut.simple_cuts.values() + simple_cuts = cut_set_with_mixed_cut.simple_cuts 
assert all(isinstance(c, MonoCut) for c in simple_cuts) assert len(simple_cuts) == 2 - mixed_cuts = cut_set_with_mixed_cut.mixed_cuts.values() + mixed_cuts = cut_set_with_mixed_cut.mixed_cuts assert all(isinstance(c, MixedCut) for c in mixed_cuts) assert len(mixed_cuts) == 1 @@ -725,3 +725,10 @@ def test_cut_set_from_files(): assert cs[0].id == "dummy-mono-cut-0000" # On second iteration, we see a different order assert cs[0].id == "dummy-mono-cut-0010" + + +def test_cut_set_duplicate_ids_allowed(): + cut = dummy_cut(0) + cuts = CutSet.from_cuts([cut, cut]) + assert len(cuts) == 2 + assert cuts[0].id == cuts[1].id diff --git a/test/test_lazy.py b/test/test_lazy.py index 38cfef579..636612caf 100644 --- a/test/test_lazy.py +++ b/test/test_lazy.py @@ -103,9 +103,9 @@ def predicate(item): with as_lazy(data) as lazy_data: lazy_result = lazy_data.filter(predicate) - with pytest.raises(NotImplementedError): - assert list(lazy_result) == list(expected) - assert list(lazy_result.to_eager()) == list(expected) + with pytest.raises(TypeError): + len(lazy_result) + assert list(lazy_result) == list(expected) @pytest.mark.parametrize(