Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add whisper feature extractor #1159

Merged
merged 4 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lhotse/features/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@
from .opensmile import OpenSmileConfig, OpenSmileExtractor
from .spectrogram import TorchaudioSpectrogram, TorchaudioSpectrogramConfig
from .ssl import S3PRLSSL, S3PRLSSLConfig
from .whisper_fbank import WhisperFbank, WhisperFbankConfig
161 changes: 161 additions & 0 deletions lhotse/features/whisper_fbank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

import numpy as np
import torch

from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import (
EPSILON,
Seconds,
asdict_nonull,
compute_num_frames_from_samples,
is_module_available,
)


def log_mel_spectrogram(
    audio: Union[np.ndarray, torch.Tensor],
    n_mels: int = 80,
    n_fft: int = 400,
    hop_length: int = 160,
    sampling_rate: int = 16000,
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """
    From https://github.com/openai/whisper/blob/main/whisper/audio.py

    Compute the Whisper-style normalized log-Mel spectrogram of an audio waveform.

    Parameters
    ----------
    audio: Union[np.ndarray, torch.Tensor]
        A NumPy array or Tensor containing the audio waveform, either 1-D
        ``(num_samples,)`` or with a leading singleton dimension
        ``(1, num_samples)``.

    n_mels: int
        The number of Mel-frequency filters.

    n_fft: int
        The FFT size, also used as the Hann window length.

    hop_length: int
        The number of samples between consecutive STFT frames.

    sampling_rate: int
        The sampling rate of ``audio`` in Hz; used to build the Mel filters
        and to compute the expected number of output frames.

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_frames, n_mels)
        A Tensor that contains the log-Mel spectrogram, located on the same
        device as the (possibly moved) audio.

    Raises
    ------
    ImportError
        If librosa is not installed.
    """
    if is_module_available("librosa"):
        import librosa
    else:
        raise ImportError(
            "Librosa is not installed. Please install librosa before using WhisperFbank extractor."
        )

    if not torch.is_tensor(audio):
        audio = torch.from_numpy(audio)
    if device is not None:
        audio = audio.to(device)
    # Drop a leading singleton (channel) dimension, if there is one.
    audio = audio.squeeze(0)

    window = torch.hann_window(n_fft, device=audio.device)
    if audio.is_floating_point():
        # Keep the window in the audio's precision (e.g. float64 NumPy input).
        window = window.to(audio.dtype)
    stft = torch.stft(audio, n_fft, hop_length, window=window, return_complex=True)
    # Power spectrum; the last STFT frame is dropped, following the Whisper reference code.
    magnitudes = stft[..., :-1].abs() ** 2

    filters = librosa.filters.mel(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels)
    # Follow the spectrogram's device and dtype: ``device`` may be None while the
    # audio already lives on an accelerator, and the audio may not be float32.
    filters = torch.from_numpy(filters).to(
        device=magnitudes.device, dtype=magnitudes.dtype
    )
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    # Limit the dynamic range to 8 log10 units below the peak, then rescale.
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0

    expected_num_frames = compute_num_frames_from_samples(
        num_samples=audio.shape[-1],
        frame_shift=hop_length / sampling_rate,
        sampling_rate=sampling_rate,
    )
    num_missing = expected_num_frames - log_spec.shape[-1]
    if num_missing > 0:
        # Right-pad the time axis with zeros up to the expected frame count.
        log_spec = torch.nn.functional.pad(log_spec, (0, num_missing), mode="constant")

    # change shape from (n_mels, n_frames) to (n_frames, n_mels)
    return log_spec.transpose(-1, -2)

Check warning on line 82 in lhotse/features/whisper_fbank.py

View check run for this annotation

Codecov / codecov/patch

lhotse/features/whisper_fbank.py#L82

Added line #L82 was not covered by tests


@dataclass
class WhisperFbankConfig:
    """Configuration of the Whisper filter bank feature extractor."""

    # Identifier of the torch device (e.g. "cpu") used for feature computation.
    device: str = "cpu"

    def to_dict(self) -> Dict[str, Any]:
        """Serialize this config into a plain dictionary via ``asdict_nonull``."""
        serialized = asdict_nonull(self)
        return serialized

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "WhisperFbankConfig":
        """Build a config from a dictionary whose keys are the config's field names."""
        config = WhisperFbankConfig(**data)
        return config

Check warning on line 94 in lhotse/features/whisper_fbank.py

View check run for this annotation

Codecov / codecov/patch

lhotse/features/whisper_fbank.py#L94

Added line #L94 was not covered by tests


@register_extractor
class WhisperFbank(FeatureExtractor):
    """
    Feature extractor computing Whisper-style log-Mel filter bank features
    with :func:`log_mel_spectrogram`.

    The extractor uses fixed settings: 16 kHz audio, 80 Mel filters and a hop
    of 160 samples. Only the computation device is configurable.
    """

    name = "whisper-fbank"
    config_type = WhisperFbankConfig

    def __init__(self, config: Optional[WhisperFbankConfig] = None):
        """
        :param config: extractor configuration; ``None`` is forwarded to the base
            class (presumably it falls back to a default config -- confirm in
            ``FeatureExtractor``).
        """
        super().__init__(config=config)
        self.sampling_rate = 16000
        self.num_filters = 80
        self.hop_length = 160

    @property
    def device(self) -> Union[str, torch.device]:
        """The device stored in the config, on which features are computed."""
        return self.config.device

    @property
    def frame_shift(self) -> Seconds:
        """Time between consecutive feature frames, in seconds."""
        return self.hop_length / self.sampling_rate

    def to(self, device: str):
        """Select the device used by subsequent calls to :meth:`extract`."""
        self.config.device = device

    def feature_dim(self, sampling_rate: int) -> int:
        """Return the number of Mel filters; ``sampling_rate`` is not used."""
        return self.num_filters

    def extract(
        self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Compute log-Mel features for a single waveform.

        :param samples: the audio samples, as a NumPy array or a torch Tensor.
        :param sampling_rate: sampling rate of ``samples``; must be equal to
            ``self.sampling_rate``.
        :return: the features, as a NumPy array when ``samples`` was a NumPy
            array and as a torch Tensor otherwise.
        """
        assert sampling_rate == self.sampling_rate, (
            f"Fbank was instantiated for sampling_rate "
            f"{self.sampling_rate}, but "
            f"sampling_rate={sampling_rate} was passed to extract(). "
            "Note you can use CutSet/RecordingSet.resample() to change the audio sampling rate."
        )

        is_numpy = not isinstance(samples, torch.Tensor)
        if is_numpy:
            samples = torch.from_numpy(samples)

        # Pass the instance settings explicitly so that the output always agrees
        # with ``frame_shift`` and ``feature_dim`` instead of relying on the
        # function defaults happening to match.
        feats = log_mel_spectrogram(
            samples,
            n_mels=self.num_filters,
            hop_length=self.hop_length,
            sampling_rate=self.sampling_rate,
            device=self.device,
        )

        # Hand the result back in the same container type that was passed in.
        return feats.cpu().numpy() if is_numpy else feats

    @staticmethod
    def mix(
        features_a: np.ndarray, features_b: np.ndarray, energy_scaling_factor_b: float
    ) -> np.ndarray:
        """
        Mix two feature matrices by adding their exponentiated values, with
        ``features_b`` scaled by ``energy_scaling_factor_b``, and taking the log.
        """
        # NOTE(review): extract() yields log10-based features rescaled by
        # (x + 4) / 4, while this uses natural exp/log -- confirm this is intended.
        return np.log(
            np.maximum(
                # protection against log(0); max with EPSILON is adequate since these are energies (always >= 0)
                EPSILON,
                np.exp(features_a) + energy_scaling_factor_b * np.exp(features_b),
            )
        )

    @staticmethod
    def compute_energy(features: np.ndarray) -> float:
        """Return the sum of the exponentiated feature values as a float."""
        # NOTE(review): same natural-exp assumption as in mix() -- confirm.
        return float(np.sum(np.exp(features)))

Check warning on line 161 in lhotse/features/whisper_fbank.py

View check run for this annotation

Codecov / codecov/patch

lhotse/features/whisper_fbank.py#L161

Added line #L161 was not covered by tests
24 changes: 24 additions & 0 deletions test/features/test_whisper_fbank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from math import ceil

import numpy as np
import pytest

from lhotse.features.whisper_fbank import WhisperFbank, WhisperFbankConfig
from lhotse.utils import is_module_available


@pytest.mark.skipif(
    not is_module_available("librosa"), reason="Librosa is an optional dependency."
)
@pytest.mark.parametrize("audio_len", [22050, 11025, 1024, 512, 24000, 16000])
def test_whisper_fbank_with_different_audio_lengths(audio_len):
    """The number of extracted frames matches the expected count to within one frame."""
    window_size = 400
    extractor = WhisperFbank(WhisperFbankConfig(device="cpu"))
    hop = extractor.hop_length
    silence = np.zeros(audio_len, dtype=np.float32)

    feats = extractor.extract(silence, 16000)

    # Frame count of a strided window with one hop of padding on each side.
    expected_n_frames = ceil((audio_len - window_size + 2 * hop) / hop + 1)
    assert abs(len(feats) - expected_n_frames) <= 1
Loading