Skip to content

Commit

Permalink
support whisper large v3; deepspeed launcher rank world_size setting (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
yuekaizhang committed Jan 12, 2024
1 parent b3373c0 commit 774ac43
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
6 changes: 6 additions & 0 deletions lhotse/dataset/sampling/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import warnings
from copy import deepcopy
from dataclasses import asdict, dataclass
Expand Down Expand Up @@ -100,6 +101,11 @@ def _maybe_init_distributed(self, world_size: Optional[int], rank: Optional[int]
assert world_size >= 1
if rank is not None:
assert rank >= 0
if "WORLD_SIZE" in os.environ and "RANK" in os.environ:
# If deepspeed launcher is being used, it will set the env variables automatically.
self.world_size = int(os.environ["WORLD_SIZE"])
self.rank = int(os.environ["RANK"])
return
if not dist.is_available() or not dist.is_initialized():
self.world_size = 1 if world_size is None else world_size
self.rank = 0 if rank is None else rank
Expand Down
5 changes: 4 additions & 1 deletion lhotse/features/whisper_fbank.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

def log_mel_spectrogram(
audio: Union[np.ndarray, torch.Tensor],
# large-v3 using 128 filters, others use 80
n_mels: int = 80,
n_fft: int = 400,
hop_length: int = 160,
Expand Down Expand Up @@ -84,6 +85,7 @@ def log_mel_spectrogram(

@dataclass
class WhisperFbankConfig:
num_filters: int = 80
device: str = "cpu"

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -102,8 +104,8 @@ class WhisperFbank(FeatureExtractor):
def __init__(self, config: Optional[WhisperFbankConfig] = None):
super().__init__(config=config)
self.sampling_rate = 16000
self.num_filters = 80
self.hop_length = 160
self.num_filters = self.config.num_filters

@property
def device(self) -> Union[str, torch.device]:
Expand Down Expand Up @@ -136,6 +138,7 @@ def extract(

feats = log_mel_spectrogram(
samples,
n_mels=self.num_filters,
device=self.device,
)

Expand Down

0 comments on commit 774ac43

Please sign in to comment.