Skip to content

Commit

Permalink
simplify & speed up the STFT
Browse files Browse the repository at this point in the history
  • Loading branch information
hollance committed Mar 6, 2023
1 parent de19c7a commit 2227cd5
Showing 1 changed file with 27 additions and 64 deletions.
91 changes: 27 additions & 64 deletions src/transformers/models/speecht5/feature_extraction_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,11 @@ def __init__(

self.sample_size = win_length * sampling_rate // 1000
self.sample_stride = hop_length * sampling_rate // 1000

self.n_fft = 2 ** int(np.ceil(np.log2(self.sample_size)))
self.n_freqs = (self.n_fft // 2) + 1

window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
self.window = window.numpy()
self.window = window.numpy().astype(np.float64)

self.mel_filter_banks = get_mel_filter_banks(
nb_frequency_bins=self.n_freqs,
Expand Down Expand Up @@ -146,79 +145,43 @@ def zero_mean_unit_var_norm(
return normed_input_values

@staticmethod
def _center_pad(one_waveform: np.ndarray, n_fft: int, pad_mode: str) -> np.ndarray:
    """
    Pads one waveform on both sides so that analysis frames can be centered on the samples.

    Args:
        one_waveform (`np.ndarray`): The waveform to pad.
        n_fft (`int`): FFT length; `n_fft // 2` values are added before and after the waveform.
        pad_mode (`str`): Padding mode, passed unchanged to `np.pad` as `mode`.

    Returns:
        `np.ndarray`: The padded waveform.
    """
    half_window = int(n_fft // 2)
    return np.pad(one_waveform, [(half_window, half_window)], mode=pad_mode)

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._num_frames_calc
def _num_frames_calc(in_size, frame_size, frame_stride):
    """
    Returns how many frames of `frame_size` samples, spaced `frame_stride` samples apart, fit into a signal of
    `in_size` samples: `1 + floor((in_size - frame_size) / frame_stride)`.
    """
    # Floor division keeps the computation exact for integers instead of going through a float.
    return int(1 + (in_size - frame_size) // frame_stride)

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._frame_signal
def _frame_signal(one_waveform, n_frames, frame_signal_scale, window_length, sample_stride):
    """
    Cuts a waveform into `n_frames` frames of `window_length` samples and stores them back to back in one flat
    array of length `n_frames * window_length`.

    Frame `i` starts at sample `i * sample_stride` of `one_waveform`; every copied sample is multiplied by
    `frame_signal_scale`.
    """
    frames = np.zeros(n_frames * window_length)
    for frame_idx in range(n_frames):
        # Frames are contiguous in the output but may overlap in the input.
        out_start = frame_idx * window_length
        wave_start = frame_idx * sample_stride
        frames[out_start : out_start + window_length] = (
            frame_signal_scale * one_waveform[wave_start : wave_start + window_length]
        )

    return frames

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._windowing
def _windowing(frames, window_length, window):
    """
    Reshapes flat frame data to `(-1, window_length)` and multiplies every row by `window`.

    Raises:
        ValueError: If `frames.size` is not a multiple of `window_length`.
    """
    if frames.size % window_length != 0:
        raise ValueError(
            f"`frames` is supposed to have length divisble by `window_length`, but is {frames.size} with"
            f" window_length={window_length}."
        )

    return window * frames.reshape(-1, window_length)

@staticmethod
# Copied from transformers.models.mctct.feature_extraction_mctct.MCTCTFeatureExtractor._dft
def _dft(frames, K, n_frames, n_samples, n_fft):
    """
    Calculates the magnitude spectrum of every frame stored in the flat array `frames`.

    Each frame holds `n_samples` values, is zero-padded at the end to `n_fft` values and transformed with
    `np.fft.rfft`; the magnitudes of the first `K` bins are kept. Returns an array of shape `(n_frames, K)`.
    """
    dft = np.zeros([n_frames, K])

    for frame in range(n_frames):
        begin = frame * n_samples

        inwards_buffer = frames[begin : begin + n_samples]
        inwards_buffer = np.pad(inwards_buffer, (0, n_fft - n_samples), "constant")
        out = np.fft.rfft(inwards_buffer)

        dft[frame] = np.abs(out[:K])

    return dft


@staticmethod
def _stft(waveform: np.ndarray, fft_length: int, hop_length: int, window: np.ndarray) -> np.ndarray:
    """
    Calculates the magnitude spectrogram over one waveform array.

    Args:
        waveform (`np.ndarray`): The input waveform.
        fft_length (`int`): Number of samples per analysis frame, also used as the FFT size.
        hop_length (`int`): Number of samples between the starts of consecutive frames.
        window (`np.ndarray`): Window applied to every frame; must contain `fft_length` values.

    Returns:
        `np.ndarray`: Array of shape `(num_frames, fft_length // 2 + 1)` with the magnitudes of the real FFT
        of each windowed frame.

    Raises:
        ValueError: If `hop_length` is not positive or `window` does not contain `fft_length` values.
    """
    if hop_length <= 0:
        raise ValueError(f"`hop_length` must be a positive integer, but is {hop_length}.")
    if window.size != fft_length:
        raise ValueError(
            f"`window` must contain `fft_length` values, but has {window.size} with fft_length={fft_length}."
        )

    # center pad the waveform
    half_length = int(fft_length // 2)
    waveform = np.pad(waveform, [(half_length, half_length)], mode="reflect")

    # promote to float64, since np.fft uses float64 internally
    waveform = waveform.astype(np.float64)

    num_frames = 1 + (waveform.size - fft_length) // hop_length
    num_frequency_bins = (fft_length // 2) + 1
    spectrogram = np.empty((num_frames, num_frequency_bins))

    for frame_idx in range(num_frames):
        start = frame_idx * hop_length
        frame = waveform[start : start + fft_length] * window
        spectrogram[frame_idx] = np.abs(np.fft.rfft(frame))

    return spectrogram

def _extract_fbank_features(
    self,
    one_waveform: np.ndarray,
) -> np.ndarray:
    """
    Extracts log-mel filterbank features for one waveform vector (unbatched). Adapted from Flashlight's C++ MFSC
    code and librosa.
    """
    one_waveform = self._center_pad(one_waveform, self.n_fft, "reflect")

    n_frames = self._num_frames_calc(one_waveform.size, self.sample_size, self.sample_stride)

    frames = self._frame_signal(
        one_waveform, n_frames, self.frame_signal_scale, self.sample_size, self.sample_stride
    )

    frames = self._windowing(frames, self.sample_size, self.window)

    dft_out = self._dft(frames.flatten(), self.n_freqs, n_frames, self.sample_size, self.n_fft)

    # Clamp to `mel_floor` before the logarithm so that empty bands do not produce -inf.
    return np.log10(np.maximum(self.mel_floor, np.dot(dft_out, self.mel_filter_banks)))


def _extract_mel_features(
    self,
    one_waveform: np.ndarray,
) -> np.ndarray:
    """
    Extracts log-mel filterbank features for one waveform array (unbatched).

    The magnitude spectrogram from `self._stft` is projected onto `self.mel_filter_banks`, clamped from below
    to `self.mel_floor` and converted with `np.log10`.

    Raises:
        NotImplementedError: If `self.n_fft` differs from `self.sample_size`.
    """
    if self.n_fft != self.sample_size:
        raise NotImplementedError("Currently the STFT frame size must be a power of two.")

    stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)

    # Clamp to `mel_floor` before the logarithm so that empty bands do not produce -inf.
    return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filter_banks)))

def __call__(
self,
Expand Down Expand Up @@ -369,7 +332,7 @@ def _process_audio(

# convert into correct format for padding
if is_target:
features = [self._extract_fbank_features(waveform) for waveform in speech]
features = [self._extract_mel_features(waveform) for waveform in speech]
encoded_inputs = BatchFeature({"input_values": features})
self.feature_size = self.num_mel_bins
else:
Expand Down

0 comments on commit 2227cd5

Please sign in to comment.