Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Whisper] Computing features on GPU in batch mode for whisper feature extractor. #29900

Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
48 changes: 32 additions & 16 deletions src/transformers/models/whisper/feature_extraction_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,41 +94,57 @@ def __init__(
mel_scale="slaney",
)

def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
def _np_extract_fbank_features(self, waveform_batch: np.array) -> np.ndarray:
"""
Compute the log-mel spectrogram of the provided audio, gives similar results to Whisper's original torch
implementation with 1e-5 tolerance.
vaibhavagg303 marked this conversation as resolved.
Show resolved Hide resolved
"""
log_spec = spectrogram(
waveform,
window_function(self.n_fft, "hann"),
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
mel_filters=self.mel_filters,
log_mel="log10",
)
log_spec = log_spec[:, :-1]
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
log_spec_batch = []
vaibhavagg303 marked this conversation as resolved.
Show resolved Hide resolved
for waveform in waveform_batch:
log_spec = spectrogram(
waveform,
window_function(self.n_fft, "hann"),
frame_length=self.n_fft,
hop_length=self.hop_length,
power=2.0,
mel_filters=self.mel_filters,
log_mel="log10",
)
log_spec = log_spec[:, :-1]
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
log_spec_batch.append(log_spec)
log_spec_batch = np.array(log_spec_batch)
return log_spec_batch

def _torch_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
"""
Compute the log-mel spectrogram of the provided audio using the PyTorch STFT implementation.
"""
device = "cuda" if torch.cuda.is_available() else None
vaibhavagg303 marked this conversation as resolved.
Show resolved Hide resolved
waveform = torch.from_numpy(waveform).type(torch.float32)

window = torch.hann_window(self.n_fft)
if device is not None:
waveform = waveform.to(device)
window = window.to(device)
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2

mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
if device is not None:
mel_filters = mel_filters.to(device)
mel_spec = mel_filters.T @ magnitudes

log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
if waveform.dim() == 2:
max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
log_spec = torch.maximum(log_spec, max_val - 8.0)
else:
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
if device is not None:
log_spec = log_spec.detach().cpu()
return log_spec.numpy()

@staticmethod
Expand Down Expand Up @@ -272,7 +288,7 @@ def __call__(
extract_fbank_features = (
self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
)
input_features = [extract_fbank_features(waveform) for waveform in input_features[0]]
input_features = extract_fbank_features(input_features[0])

if isinstance(input_features[0], List):
padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
Expand Down