Skip to content

Commit

Permalink
Add whisper feature extractor (#1159)
Browse files Browse the repository at this point in the history
* add whisper feature extractor

* hard coding config

* add test for whisper feature extractor
  • Loading branch information
yuekaizhang committed Sep 22, 2023
1 parent 7b60f86 commit 3875788
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 0 deletions.
1 change: 1 addition & 0 deletions lhotse/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@
from .opensmile import OpenSmileConfig, OpenSmileExtractor
from .spectrogram import TorchaudioSpectrogram, TorchaudioSpectrogramConfig
from .ssl import S3PRLSSL, S3PRLSSLConfig
from .whisper_fbank import WhisperFbank, WhisperFbankConfig
161 changes: 161 additions & 0 deletions lhotse/features/whisper_fbank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import (
EPSILON,
Seconds,
asdict_nonull,
compute_num_frames_from_samples,
is_module_available,
)


def log_mel_spectrogram(
audio: Union[np.ndarray, torch.Tensor],
n_mels: int = 80,
n_fft: int = 400,
hop_length: int = 160,
sampling_rate: int = 16000,
device: Optional[Union[str, torch.device]] = None,
):
"""
From https://github.com/openai/whisper/blob/main/whisper/audio.py
Compute the log-Mel spectrogram of
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 is supported
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (n_frames, 80)
A Tensor that contains the Mel spectrogram
"""
if is_module_available("librosa"):
import librosa
else:
raise ImportError(
"Librosa is not installed. Please install librosa before using LibrosaFbank extractor."
)
if not torch.is_tensor(audio):
audio = torch.from_numpy(audio)

if device is not None:
audio = audio.to(device)
audio = audio.squeeze(0)
window = torch.hann_window(n_fft).to(audio.device)
stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

filters = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels)
filters = torch.from_numpy(filters).to(device)
mel_spec = filters @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

padding = compute_num_frames_from_samples(
num_samples=len(audio),
frame_shift=hop_length / sampling_rate,
sampling_rate=sampling_rate,
)
if padding > log_spec.shape[1]:
log_spec = torch.nn.functional.pad(
log_spec, (0, padding - log_spec.shape[1]), mode="constant"
)
# change shape from 80, n_frames to n_frames,80
log_spec = log_spec.transpose(0, 1)

return log_spec


@dataclass
class WhisperFbankConfig:
device: str = "cpu"

def to_dict(self) -> Dict[str, Any]:
return asdict_nonull(self)

@staticmethod
def from_dict(data: Dict[str, Any]) -> "WhisperFbankConfig":
return WhisperFbankConfig(**data)


@register_extractor
class WhisperFbank(FeatureExtractor):
name = "whisper-fbank"
config_type = WhisperFbankConfig

def __init__(self, config: Optional[WhisperFbankConfig] = None):
super().__init__(config=config)
self.sampling_rate = 16000
self.num_filters = 80
self.hop_length = 160

@property
def device(self) -> Union[str, torch.device]:
return self.config.device

@property
def frame_shift(self) -> Seconds:
return self.hop_length / self.sampling_rate

def to(self, device: str):
self.config.device = device

def feature_dim(self, sampling_rate: int) -> int:
return self.num_filters

def extract(
self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
) -> Union[np.ndarray, torch.Tensor]:
assert sampling_rate == self.sampling_rate, (
f"Fbank was instantiated for sampling_rate "
f"{self.sampling_rate}, but "
f"sampling_rate={sampling_rate} was passed to extract(). "
"Note you can use CutSet/RecordingSet.resample() to change the audio sampling rate."
)

is_numpy = False
if not isinstance(samples, torch.Tensor):
samples = torch.from_numpy(samples)
is_numpy = True

feats = log_mel_spectrogram(
samples,
device=self.device,
)

if is_numpy:
return feats.cpu().numpy()
else:
return feats

@staticmethod
def mix(
features_a: np.ndarray, features_b: np.ndarray, energy_scaling_factor_b: float
) -> np.ndarray:
return np.log(
np.maximum(
# protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
EPSILON,
np.exp(features_a) + energy_scaling_factor_b * np.exp(features_b),
)
)

@staticmethod
def compute_energy(features: np.ndarray) -> float:
return float(np.sum(np.exp(features)))
24 changes: 24 additions & 0 deletions test/features/test_whisper_fbank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from math import ceil

import numpy as np
import pytest

from lhotse.features.whisper_fbank import WhisperFbank, WhisperFbankConfig
from lhotse.utils import is_module_available


@pytest.mark.skipif(
not is_module_available("librosa"), reason="Librosa is an optional dependency."
)
@pytest.mark.parametrize("audio_len", [22050, 11025, 1024, 512, 24000, 16000])
def test_whisper_fbank_with_different_audio_lengths(audio_len):

extractor = WhisperFbank(WhisperFbankConfig(device="cpu"))

kernel_size = 400
stride = extractor.hop_length
pad = stride
expected_n_frames = ceil((audio_len - kernel_size + 2 * pad) / stride + 1)

n_frames = len(extractor.extract(np.zeros(audio_len, dtype=np.float32), 16000))
assert abs(n_frames - expected_n_frames) <= 1

0 comments on commit 3875788

Please sign in to comment.