From 774ac439a49b4f0d4eef537379115fafe3b74cbb Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 12 Jan 2024 23:25:00 +0800 Subject: [PATCH] support whisper large v3; deepspeed launcher rank world_size setting (#1260) --- lhotse/dataset/sampling/base.py | 6 ++++++ lhotse/features/whisper_fbank.py | 5 ++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/lhotse/dataset/sampling/base.py b/lhotse/dataset/sampling/base.py index 393c48e7a..1b5943a20 100644 --- a/lhotse/dataset/sampling/base.py +++ b/lhotse/dataset/sampling/base.py @@ -1,3 +1,4 @@ +import os import warnings from copy import deepcopy from dataclasses import asdict, dataclass @@ -100,6 +101,11 @@ def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int] assert world_size >= 1 if rank is not None: assert rank >= 0 + if "WORLD_SIZE" in os.environ and "RANK" in os.environ: + # If deepspeed launcher is being used, it will set the env variables automatically. + self.world_size = int(os.environ["WORLD_SIZE"]) + self.rank = int(os.environ["RANK"]) + return if not dist.is_available() or not dist.is_initialized(): self.world_size = 1 if world_size is None else world_size self.rank = 0 if rank is None else rank diff --git a/lhotse/features/whisper_fbank.py b/lhotse/features/whisper_fbank.py index d112e995b..df69071ff 100644 --- a/lhotse/features/whisper_fbank.py +++ b/lhotse/features/whisper_fbank.py @@ -16,6 +16,7 @@ def log_mel_spectrogram( audio: Union[np.ndarray, torch.Tensor], + # large-v3 using 128 filters, others use 80 n_mels: int = 80, n_fft: int = 400, hop_length: int = 160, @@ -84,6 +85,7 @@ def log_mel_spectrogram( @dataclass class WhisperFbankConfig: + num_filters: int = 80 device: str = "cpu" def to_dict(self) -> Dict[str, Any]: @@ -102,8 +104,8 @@ class WhisperFbank(FeatureExtractor): def __init__(self, config: Optional[WhisperFbankConfig] = None): super().__init__(config=config) self.sampling_rate = 16000 - self.num_filters = 80 self.hop_length = 160 + self.num_filters = self.config.num_filters @property def device(self) -> Union[str, torch.device]: @@ -136,6 +138,7 @@ def extract( feats = log_mel_spectrogram( samples, + n_mels=self.num_filters, device=self.device, )