Skip to content

Commit

Permalink
Fixes for manifest validation and fixing (#1284)
Browse files Browse the repository at this point in the history
  • Loading branch information
pzelasko committed Feb 11, 2024
1 parent 00abc09 commit 5321577
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 17 deletions.
16 changes: 12 additions & 4 deletions lhotse/bin/modes/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@ def validate_(manifest: Pathlike, read_data: bool):
from lhotse import load_manifest, validate

data = load_manifest(manifest)
validate(data, read_data=read_data)
try:
validate(data, read_data=read_data)
except AssertionError as e:
click.echo(f"Validation failed: {e}")
return 1


@cli.command(name="validate-pair")
Expand All @@ -40,9 +44,13 @@ def validate_(recordings: Pathlike, supervisions: Pathlike, read_data: bool):

recs = load_manifest(recordings)
sups = load_manifest(supervisions)
validate_recordings_and_supervisions(
recordings=recs, supervisions=sups, read_data=read_data
)
try:
validate_recordings_and_supervisions(
recordings=recs, supervisions=sups, read_data=read_data
)
except AssertionError as e:
click.echo(f"Validation failed: {e}")
return 1


@cli.command(name="fix")
Expand Down
40 changes: 27 additions & 13 deletions lhotse/qa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from collections import defaultdict
from collections import Counter, defaultdict
from math import isclose
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

Expand Down Expand Up @@ -88,20 +88,22 @@ def validate_recordings_and_supervisions(
These items will be discarded by default when creating a CutSet.
"""
if isinstance(recordings, Recording):
recordings = RecordingSet.from_recordings([recordings])
recordings = RecordingSet([recordings])
if isinstance(supervisions, SupervisionSegment):
supervisions = SupervisionSet.from_segments([supervisions])
supervisions = SupervisionSet([supervisions])

if recordings.is_lazy:
recordings = RecordingSet.from_recordings(iter(recordings))
if supervisions.is_lazy:
supervisions = SupervisionSet.from_segments(iter(supervisions))
recordings = recordings.to_eager()
supervisions = supervisions.to_eager()

validate(recordings, read_data=read_data)
validate(supervisions)
# Errors
id2rec = {r.id: r for r in recordings}
for s in supervisions:
r = recordings[s.recording_id]
r = id2rec.get(s.recording_id)
assert (
r is not None
), f"Supervision {s.id} references non-existent recording {s.recording_id}"
assert -1e-3 <= s.start <= s.end <= r.duration + 1e-3, (
f"Supervision {s.id}: exceeded the bounds of its corresponding recording "
f"(supervision spans [{s.start}, {s.end}]; recording spans [0, {r.duration}])"
Expand All @@ -111,7 +113,7 @@ def validate_recordings_and_supervisions(
f"(recording channels: {r.channel_ids})"
)
# Warnings
recording_ids = frozenset(r.id for r in recordings)
recording_ids = id2rec.keys()
recording_ids_in_sups = frozenset(s.recording_id for s in supervisions)
only_in_recordings = recording_ids - recording_ids_in_sups
if only_in_recordings:
Expand Down Expand Up @@ -172,15 +174,14 @@ def trim_supervisions_to_recordings(
not exceeding the duration of their corresponding :class:`~lhotse.audio.Recording`.
"""
if isinstance(recordings, Recording):
recordings = RecordingSet.from_recordings([recordings])
if recordings.is_lazy:
recordings = RecordingSet.from_recordings(iter(recordings))
recordings = RecordingSet([recordings])

id2rec = {r.id: r for r in recordings}
sups = []
removed = 0
trimmed = 0
for s in supervisions:
end = recordings[s.recording_id].duration
end = id2rec[s.recording_id].duration
if s.start > end:
removed += 1
continue
Expand Down Expand Up @@ -438,20 +439,30 @@ def validate_cut(c: Cut, read_data: bool = False) -> None:
@register_validator
def validate_recording_set(recordings: RecordingSet, read_data: bool = False) -> None:
rates = set()
ids = Counter()
for r in recordings:
validate_recording(r, read_data=read_data)
rates.add(r.sampling_rate)
ids[r.id] += 1
if len(rates) > 1:
logging.warning(
f"RecordingSet contains recordings with different sampling rates ({rates}). "
f"Make sure that this was intended."
)
assert (
ids.most_common(1)[0][1] <= 1
), "RecordingSet has recordings with duplicated IDs."


@register_validator
def validate_supervision_set(supervisions: SupervisionSet, **kwargs) -> None:
ids = Counter()
for s in supervisions:
validate_supervision(s)
ids[s.id] += 1
assert (
ids.most_common(1)[0][1] <= 1
), "SupervisionSet has supervisions with duplicated IDs."

# Catch errors in data preparation:
# - more than one supervision for a given recording starts at 0 (in a given channel)
Expand Down Expand Up @@ -494,5 +505,8 @@ def validate_feature_set(features: FeatureSet, read_data: bool = False) -> None:

@register_validator
def validate_cut_set(cuts: CutSet, read_data: bool = False) -> None:
ids = Counter()
for c in cuts:
validate_cut(c, read_data=read_data)
ids[c.id] += 1
assert ids.most_common(1)[0][1] <= 1, "CutSet has cuts with duplicated IDs."

0 comments on commit 5321577

Please sign in to comment.