Skip to content

Commit

Permalink
[Whisper] Computing features on GPU in batch mode for whisper feature…
Browse files Browse the repository at this point in the history
… extractor. (#29900)

* add _torch_extract_fbank_features_batch function in feature_extractor_whisper

* reformat feature_extraction_whisper.py file

* handle batching in single function

* add gpu test & doc

* add batch test & device in each __call__

* add device arg in doc string

---------

Co-authored-by: vaibhav.aggarwal <vaibhav.aggarwal@sprinklr.com>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent 01a28c4 commit 6274bc0
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 20 deletions.
64 changes: 45 additions & 19 deletions src/transformers/models/whisper/feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,41 +94,63 @@ def __init__(
mel_scale="slaney",
)

def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
"""
Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
implementation with 1e-5 tolerance.
"""
log_spec = spectrogram(
waveform,
window_function(self.n_fft, "hann"),
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
mel_filters=self.mel_filters,
log_mel="log10",
)
log_spec = log_spec[:, :-1]
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec

def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
if device != "cpu":
raise ValueError(
f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
"devices requires torch, which is not installed. Either set `device='cpu'`, or "
"install torch according to the official instructions: https://pytorch.org/get-started/locally/"
)
log_spec_batch = []
for waveform in waveform_batch:
log_spec = spectrogram(
waveform,
window_function(self.n_fft, "hann"),
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
mel_filters=self.mel_filters,
log_mel="log10",
)
log_spec = log_spec[:, :-1]
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
log_spec_batch.append(log_spec)
log_spec_batch = np.array(log_spec_batch)
return log_spec_batch

def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray:
"""
Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.
Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
yielding results similar to cpu computing with 1e-5 tolerance.
"""
waveform = torch.from_numpy(waveform).type(torch.float32)

window = torch.hann_window(self.n_fft)
if device != "cpu":
waveform = waveform.to(device)
window = window.to(device)
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
if device != "cpu":
mel_filters = mel_filters.to(device)
mel_spec = mel_filters.T @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
if waveform.dim() == 2:
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
else:
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
if device != "cpu":
log_spec = log_spec.detach().cpu()
return log_spec.numpy()

@staticmethod
Expand Down Expand Up @@ -165,6 +187,7 @@ def __call__(
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
do_normalize: Optional[bool] = None,
device: Optional[str] = "cpu",
**kwargs,
) -> BatchFeature:
"""
Expand Down Expand Up @@ -211,6 +234,9 @@ def __call__(
do_normalize (`bool`, *optional*, defaults to `False`):
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
improve the performance of the model.
device (`str`, *optional*, defaults to `'cpu'`):
Specifies the device for computation of the log-mel spectrogram of audio signals in the
`_torch_extract_fbank_features` method. (e.g., "cpu", "cuda")
"""

if sampling_rate is not None:
Expand Down Expand Up @@ -272,7 +298,7 @@ def __call__(
extract_fbank_features = (
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
)
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
input_features = extract_fbank_features(input_features[0], device)

if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
Expand Down
38 changes: 37 additions & 1 deletion tests/models/whisper/test_feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from datasets import load_dataset

from transformers import WhisperFeatureExtractor
from transformers.testing_utils import check_json_file_has_correct_format, require_torch
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_torch_gpu
from transformers.utils.import_utils import is_torch_available

from ...test_sequence_feature_extraction_common import SequenceFeatureExtractionTestMixin
Expand Down Expand Up @@ -207,6 +207,7 @@ def _load_datasamples(self, num_samples):

return [x["array"] for x in speech_samples]

@require_torch_gpu
@require_torch
def test_torch_integration(self):
# fmt: off
Expand All @@ -223,6 +224,7 @@ def test_torch_integration(self):
input_speech = self._load_datasamples(1)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features

self.assertEqual(input_features.shape, (1, 80, 3000))
self.assertTrue(torch.allclose(input_features[0, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))

Expand Down Expand Up @@ -253,3 +255,37 @@ def test_zero_mean_unit_variance_normalization_trunc_np_longest(self):

self.assertTrue(np.all(np.mean(audio) < 1e-3))
self.assertTrue(np.all(np.abs(np.var(audio) - 1) < 1e-3))

@require_torch_gpu
@require_torch
def test_torch_integration_batch(self):
# fmt: off
EXPECTED_INPUT_FEATURES = torch.tensor(
[
[
0.1193, -0.0946, -0.1098, -0.0196, 0.0225, -0.0690, -0.1736, 0.0951,
0.0971, -0.0817, -0.0702, 0.0162, 0.0260, 0.0017, -0.0192, -0.1678,
0.0709, -0.1867, -0.0655, -0.0274, -0.0234, -0.1884, -0.0516, -0.0554,
-0.0274, -0.1425, -0.1423, 0.0837, 0.0377, -0.0854
],
[
-0.4696, -0.0751, 0.0276, -0.0312, -0.0540, -0.0383, 0.1295, 0.0568,
-0.2071, -0.0548, 0.0389, -0.0316, -0.2346, -0.1068, -0.0322, 0.0475,
-0.1709, -0.0041, 0.0872, 0.0537, 0.0075, -0.0392, 0.0371, 0.0189,
-0.1522, -0.0270, 0.0744, 0.0738, -0.0245, -0.0667
],
[
-0.2337, -0.0060, -0.0063, -0.2353, -0.0431, 0.1102, -0.1492, -0.0292,
0.0787, -0.0608, 0.0143, 0.0582, 0.0072, 0.0101, -0.0444, -0.1701,
-0.0064, -0.0027, -0.0826, -0.0730, -0.0099, -0.0762, -0.0170, 0.0446,
-0.1153, 0.0960, -0.0361, 0.0652, 0.1207, 0.0277
]
]
)
# fmt: on

input_speech = self._load_datasamples(3)
feature_extractor = WhisperFeatureExtractor()
input_features = feature_extractor(input_speech, return_tensors="pt").input_features
self.assertEqual(input_features.shape, (3, 80, 3000))
self.assertTrue(torch.allclose(input_features[:, 0, :30], EXPECTED_INPUT_FEATURES, atol=1e-4))

0 comments on commit 6274bc0

Please sign in to comment.