From 53215777fdb26e3ba8df9ed43ec8677b751b5e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 11 Feb 2024 18:35:52 -0500 Subject: [PATCH] Fixes for manifest validation and fixing (#1284) --- lhotse/bin/modes/validate.py | 16 +++++++++++---- lhotse/qa.py | 40 ++++++++++++++++++++++++------------ 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/lhotse/bin/modes/validate.py b/lhotse/bin/modes/validate.py index e6f831a0a..650e098ba 100644 --- a/lhotse/bin/modes/validate.py +++ b/lhotse/bin/modes/validate.py @@ -19,7 +19,11 @@ def validate_(manifest: Pathlike, read_data: bool): from lhotse import load_manifest, validate data = load_manifest(manifest) - validate(data, read_data=read_data) + try: + validate(data, read_data=read_data) + except AssertionError as e: + click.echo(f"Validation failed: {e}") + return 1 @cli.command(name="validate-pair") @@ -40,9 +44,13 @@ def validate_(recordings: Pathlike, supervisions: Pathlike, read_data: bool): recs = load_manifest(recordings) sups = load_manifest(supervisions) - validate_recordings_and_supervisions( - recordings=recs, supervisions=sups, read_data=read_data - ) + try: + validate_recordings_and_supervisions( + recordings=recs, supervisions=sups, read_data=read_data + ) + except AssertionError as e: + click.echo(f"Validation failed: {e}") + return 1 @cli.command(name="fix") diff --git a/lhotse/qa.py b/lhotse/qa.py index b1d20ce73..9b7c07630 100644 --- a/lhotse/qa.py +++ b/lhotse/qa.py @@ -1,5 +1,5 @@ import logging -from collections import defaultdict +from collections import Counter, defaultdict from math import isclose from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union @@ -88,20 +88,22 @@ def validate_recordings_and_supervisions( These items will be discarded by default when creating a CutSet. """ if isinstance(recordings, Recording): - recordings = RecordingSet.from_recordings([recordings]) + recordings = RecordingSet([recordings]) if isinstance(supervisions, SupervisionSegment): - supervisions = SupervisionSet.from_segments([supervisions]) + supervisions = SupervisionSet([supervisions]) - if recordings.is_lazy: - recordings = RecordingSet.from_recordings(iter(recordings)) - if supervisions.is_lazy: - supervisions = SupervisionSet.from_segments(iter(supervisions)) + recordings = recordings.to_eager() + supervisions = supervisions.to_eager() validate(recordings, read_data=read_data) validate(supervisions) # Errors + id2rec = {r.id: r for r in recordings} for s in supervisions: - r = recordings[s.recording_id] + r = id2rec.get(s.recording_id) + assert ( + r is not None + ), f"Supervision {s.id} references non-existent recording {s.recording_id}" assert -1e-3 <= s.start <= s.end <= r.duration + 1e-3, ( f"Supervision {s.id}: exceeded the bounds of its corresponding recording " f"(supervision spans [{s.start}, {s.end}]; recording spans [0, {r.duration}])" @@ -111,7 +113,7 @@ def validate_recordings_and_supervisions( f"(recording channels: {r.channel_ids})" ) # Warnings - recording_ids = frozenset(r.id for r in recordings) + recording_ids = id2rec.keys() recording_ids_in_sups = frozenset(s.recording_id for s in supervisions) only_in_recordings = recording_ids - recording_ids_in_sups if only_in_recordings: @@ -172,15 +174,14 @@ def trim_supervisions_to_recordings( not exceeding the duration of their corresponding :class:`~lhotse.audio.Recording`. """ if isinstance(recordings, Recording): - recordings = RecordingSet.from_recordings([recordings]) - if recordings.is_lazy: - recordings = RecordingSet.from_recordings(iter(recordings)) + recordings = RecordingSet([recordings]) + id2rec = {r.id: r for r in recordings} sups = [] removed = 0 trimmed = 0 for s in supervisions: - end = recordings[s.recording_id].duration + end = id2rec[s.recording_id].duration if s.start > end: removed += 1 continue @@ -438,20 +439,30 @@ def validate_cut(c: Cut, read_data: bool = False) -> None: @register_validator def validate_recording_set(recordings: RecordingSet, read_data: bool = False) -> None: rates = set() + ids = Counter() for r in recordings: validate_recording(r, read_data=read_data) rates.add(r.sampling_rate) + ids[r.id] += 1 if len(rates) > 1: logging.warning( f"RecordingSet contains recordings with different sampling rates ({rates}). " f"Make sure that this was intended." ) + assert ( + ids.most_common(1)[0][1] <= 1 + ), "RecordingSet has recordings with duplicated IDs." @register_validator def validate_supervision_set(supervisions: SupervisionSet, **kwargs) -> None: + ids = Counter() for s in supervisions: validate_supervision(s) + ids[s.id] += 1 + assert ( + ids.most_common(1)[0][1] <= 1 + ), "SupervisionSet has supervisions with duplicated IDs." # Catch errors in data preparation: # - more than one supervision for a given recording starts at 0 (in a given channel) @@ -494,5 +505,8 @@ def validate_feature_set(features: FeatureSet, read_data: bool = False) -> None: @register_validator def validate_cut_set(cuts: CutSet, read_data: bool = False) -> None: + ids = Counter() for c in cuts: validate_cut(c, read_data=read_data) + ids[c.id] += 1 + assert ids.most_common(1)[0][1] <= 1, "CutSet has cuts with duplicated IDs."