Skip to content

Commit

Permalink
Merge branch 'pa-speaker-detector-robustness' into 'main'
Browse files Browse the repository at this point in the history
ENH: Improve robustness of PASpeakerDetector

See merge request heka/medkit!245

changelog: ENH: Improve robustness of PASpeakerDetector
  • Loading branch information
ghisvail committed Dec 13, 2023
2 parents 25ed286 + 6f05446 commit f4b0687
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 12 deletions.
13 changes: 11 additions & 2 deletions medkit/audio/segmentation/pa_speaker_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from medkit.core.audio import Segment, SegmentationOperation, Span


# margin (in seconds) by which a turn segment
# may overrun the input segment due to imprecision
_DURATION_MARGIN = 0.1


class PASpeakerDetector(SegmentationOperation):
"""Speaker diarization operation relying on `pyannote.audio`
Expand Down Expand Up @@ -150,12 +155,16 @@ def _detect_turns_in_segment(self, segment: Segment) -> Iterator[Segment]:
for turn, _, speaker in diarization.itertracks(yield_label=True):
if turn.duration < self.min_duration:
continue

# trim original audio to turn start/end points
turn_audio = audio.trim_duration(turn.start, turn.end)
# (allow pyannote's turn to be slightly over the total input duration)
assert turn.end < audio.duration + _DURATION_MARGIN
turn_end = min(turn.end, audio.duration)
turn_audio = audio.trim_duration(turn.start, turn_end)

turn_span = Span(
start=segment.span.start + turn.start,
end=segment.span.start + turn.end,
end=segment.span.start + turn_end,
)
speaker_attr = Attribute(label="speaker", value=speaker)
turn_segment = Segment(
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/audio/metrics/test_transcription_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ def _get_doc():
turn_seg_1 = Segment(
label="speech",
audio=_FULL_AUDIO.trim_duration(start_time=0.0, end_time=2.0),
span=Span(start=0.0, end=4.0),
span=Span(start=0.0, end=2.0),
attrs=[Attribute(label="transcription", value="Bonjour ça va bien ?")],
)
doc.anns.add(turn_seg_1)

turn_seg_2 = Segment(
label="speech",
audio=_FULL_AUDIO.trim_duration(start_time=2.0, end_time=4.0),
span=Span(5.0, 6.0),
span=Span(2.0, 4.0),
attrs=[Attribute(label="transcription", value="Ça va et vous ?")],
)
doc.anns.add(turn_seg_2)
Expand All @@ -45,7 +45,7 @@ def _get_doc():
"identical": (
[
{"start": 0.0, "end": 2.0, "transcription": "Bonjour ça va bien ?"},
{"start": 2.0, "end": 4.0, "transcription": "Ça va et vous ?"},
{"start": 3.0, "end": 4.0, "transcription": "Ça va et vous ?"},
],
{},
TranscriptionEvaluatorResult(
Expand All @@ -65,7 +65,7 @@ def _get_doc():
"errors": (
[
{"start": 0.0, "end": 2.0, "transcription": "Bonjour ça va ?"},
{"start": 2.0, "end": 4.0, "transcription": "Bien et vous ?"},
{"start": 3.0, "end": 4.0, "transcription": "Bien et vous ?"},
],
{},
TranscriptionEvaluatorResult(
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/audio/preprocessing/test_downmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@


def _get_segment(signal):
duration = signal.shape[1] * _SAMPLE_RATE
duration = signal.shape[1] / _SAMPLE_RATE
audio = MemoryAudioBuffer(signal=signal, sample_rate=_SAMPLE_RATE)
return Segment(label="raw", span=Span(_SPAN_OFFSET, duration), audio=audio)
return Segment(
label="raw", span=Span(_SPAN_OFFSET, _SPAN_OFFSET + duration), audio=audio
)


def _check_downmixed_segment(downmixed_seg, original_seg):
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/audio/preprocessing/test_power_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@


def _get_segment(signal):
duration = signal.shape[1] * _SAMPLE_RATE
duration = signal.shape[1] / _SAMPLE_RATE
audio = MemoryAudioBuffer(signal=signal, sample_rate=_SAMPLE_RATE)
return Segment(label="raw", span=Span(_SPAN_OFFSET, duration), audio=audio)
return Segment(
label="raw", span=Span(_SPAN_OFFSET, _SPAN_OFFSET + duration), audio=audio
)


def _check_normalized_segment(normalized_seg, original_seg):
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/audio/preprocessing/test_resampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@


def _get_segment(signal, sample_rate):
duration = signal.shape[1] * sample_rate
duration = signal.shape[1] / sample_rate
audio = MemoryAudioBuffer(signal=signal, sample_rate=sample_rate)
return Segment(label="raw", span=Span(_SPAN_OFFSET, duration), audio=audio)
return Segment(
label="raw", span=Span(_SPAN_OFFSET, _SPAN_OFFSET + duration), audio=audio
)


def _check_resampled_segment(resampled_seg, original_seg):
Expand Down

0 comments on commit f4b0687

Please sign in to comment.