From 1cd2d88b18672eb0f720f990a3381a982d20b7f4 Mon Sep 17 00:00:00 2001 From: Hyeonseung Lee Date: Fri, 1 Mar 2024 01:13:54 +0900 Subject: [PATCH] Cutconcat fixed max duration (#1292) * CutConcatenate: added fixed max_duration parameter * pre-commit applied --- lhotse/dataset/cut_transforms/concatenate.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/lhotse/dataset/cut_transforms/concatenate.py b/lhotse/dataset/cut_transforms/concatenate.py index 977c92884..2f5955153 100644 --- a/lhotse/dataset/cut_transforms/concatenate.py +++ b/lhotse/dataset/cut_transforms/concatenate.py @@ -12,7 +12,12 @@ class CutConcatenate: adding some silence between them to avoid a large number of padding frames that waste the computation. """ - def __init__(self, gap: Seconds = 1.0, duration_factor: float = 1.0) -> None: + def __init__( + self, + gap: Seconds = 1.0, + duration_factor: float = 1.0, + max_duration: Optional[Seconds] = None, + ) -> None: """ CutConcatenate's constructor. @@ -20,14 +25,21 @@ def __init__(self, gap: Seconds = 1.0, duration_factor: float = 1.0) -> None: it's goal is to let the model "know" that there are separate utterances in a single example. :param duration_factor: Determines the maximum duration of the concatenated cuts; by default it's 1, setting the limit at the duration of the longest cut in the batch. + :param max_duration: If a value is given (in seconds), the maximum duration of concatenated cuts + is fixed to the value while duration_factor is ignored. """ self.gap = gap self.duration_factor = duration_factor + self.max_duration = max_duration def __call__(self, cuts: CutSet) -> CutSet: cuts = cuts.sort_by_duration(ascending=False) return concat_cuts( - cuts, gap=self.gap, max_duration=cuts[0].duration * self.duration_factor + cuts, + gap=self.gap, + max_duration=self.max_duration + if self.max_duration + else cuts[0].duration * self.duration_factor, )