diff --git a/lhotse/dataset/cut_transforms/concatenate.py b/lhotse/dataset/cut_transforms/concatenate.py
index 977c92884..2f5955153 100644
--- a/lhotse/dataset/cut_transforms/concatenate.py
+++ b/lhotse/dataset/cut_transforms/concatenate.py
@@ -12,7 +12,12 @@ class CutConcatenate:
     adding some silence between them to avoid a large number of padding frames that waste the computation.
     """
 
-    def __init__(self, gap: Seconds = 1.0, duration_factor: float = 1.0) -> None:
+    def __init__(
+        self,
+        gap: Seconds = 1.0,
+        duration_factor: float = 1.0,
+        max_duration: Optional[Seconds] = None,
+    ) -> None:
         """
         CutConcatenate's constructor.
 
@@ -20,14 +25,21 @@ def __init__(self, gap: Seconds = 1.0, duration_factor: float = 1.0) -> None:
             it's goal is to let the model "know" that there are separate utterances in a single example.
         :param duration_factor: Determines the maximum duration of the concatenated cuts;
             by default it's 1, setting the limit at the duration of the longest cut in the batch.
+        :param max_duration: If not ``None``, a fixed upper limit (in seconds) on the duration of
+            the concatenated cuts; when it is set, ``duration_factor`` is ignored.
         """
         self.gap = gap
         self.duration_factor = duration_factor
+        self.max_duration = max_duration
 
     def __call__(self, cuts: CutSet) -> CutSet:
         cuts = cuts.sort_by_duration(ascending=False)
         return concat_cuts(
-            cuts, gap=self.gap, max_duration=cuts[0].duration * self.duration_factor
+            cuts,
+            gap=self.gap,
+            max_duration=self.max_duration
+            if self.max_duration is not None
+            else cuts[0].duration * self.duration_factor,
         )
 
 