Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Whisper] Computing features on GPU in batch mode for whisper feature extractor. #29900

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/transformers/models/whisper/feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,27 @@ def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
log_spec = (log_spec + 4.0) / 4.0
return log_spec.numpy()

def _torch_extract_fbank_features_batch(self, waveforms: np.array, device: str) -> np.ndarray:
vaibhavagg303 marked this conversation as resolved.
Show resolved Hide resolved
"""
Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.
"""
waveforms = torch.from_numpy(waveforms).type(torch.float32).to(device)

window = torch.hann_window(self.n_fft).to(device)
stft = torch.stft(waveforms, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32).to(device)
mel_spec = mel_filters.T @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
log_spec = (log_spec + 4.0) / 4.0
if device != 'cpu':
log_spec = log_spec.detach().cpu()
return log_spec.numpy()

@staticmethod
# Copied from transformers.models.wav2vec2.feature_extraction_wav2vec2.Wav2Vec2FeatureExtractor.zero_mean_unit_var_norm
def zero_mean_unit_var_norm(
Expand Down Expand Up @@ -165,6 +186,7 @@ def __call__(
max_length: Optional[int] = None,
sampling_rate: Optional[int] = None,
do_normalize: Optional[bool] = None,
device: Optional[str] = 'cpu',
**kwargs,
) -> BatchFeature:
"""
Expand Down Expand Up @@ -270,9 +292,12 @@ def __call__(
input_features = padded_inputs.get("input_features").transpose(2, 0, 1)

extract_fbank_features = (
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
self._torch_extract_fbank_features_batch if is_torch_available() else self._np_extract_fbank_features
)
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
if is_torch_available():
input_features = extract_fbank_features(input_features[0], device)
else:
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
vaibhavagg303 marked this conversation as resolved.
Show resolved Hide resolved

if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
Expand Down