diff --git a/espnet2/asr/specaug/specaug.py b/espnet2/asr/specaug/specaug.py index f35d1468585..8dde0f53b1f 100644 --- a/espnet2/asr/specaug/specaug.py +++ b/espnet2/asr/specaug/specaug.py @@ -33,6 +33,7 @@ def __init__( time_mask_width_range: Optional[Union[int, Sequence[int]]] = None, time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None, num_time_mask: int = 2, + replace_with_zero: bool = True, ): if not apply_time_warp and not apply_time_mask and not apply_freq_mask: raise ValueError( @@ -62,6 +63,7 @@ def __init__( dim="freq", mask_width_range=freq_mask_width_range, num_mask=num_freq_mask, + replace_with_zero=replace_with_zero, ) else: self.freq_mask = None @@ -72,12 +74,14 @@ def __init__( dim="time", mask_width_range=time_mask_width_range, num_mask=num_time_mask, + replace_with_zero=replace_with_zero, ) elif time_mask_width_ratio_range is not None: self.time_mask = MaskAlongAxisVariableMaxWidth( dim="time", mask_width_ratio_range=time_mask_width_ratio_range, num_mask=num_time_mask, + replace_with_zero=replace_with_zero, ) else: raise ValueError(