Skip to content

Commit

Permalink
don't serialize window and mel filters
Browse files Browse the repository at this point in the history
  • Loading branch information
hollance committed Mar 7, 2023
1 parent 4737fcd commit 52d0c2f
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/transformers/models/speecht5/feature_extraction_speecht5.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
"""Feature extractor class for SpeechT5."""

from typing import List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(
window = getattr(torch, self.win_function)(window_length=self.sample_size, periodic=True)
self.window = window.numpy().astype(np.float64)

self.mel_filter_banks = get_mel_filter_banks(
self.mel_filters = get_mel_filter_banks(
nb_frequency_bins=self.n_freqs,
nb_mel_filters=self.num_mel_bins,
frequency_min=self.fmin,
Expand Down Expand Up @@ -181,7 +181,7 @@ def _extract_mel_features(

stft_out = self._stft(one_waveform, self.n_fft, self.sample_stride, self.window)

return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filter_banks)))
return np.log10(np.maximum(self.mel_floor, np.dot(stft_out, self.mel_filters)))

def __call__(
self,
Expand Down Expand Up @@ -383,3 +383,14 @@ def _process_audio(
padded_inputs = padded_inputs.convert_to_tensors(return_tensors)

return padded_inputs

def to_dict(self) -> Dict[str, Any]:
    """
    Serialize this feature extractor to a Python dictionary.

    Starts from the dictionary produced by the parent class's `to_dict()` and removes the entries that are derived
    from the other properties, so they are not written out with the configuration.

    Returns:
        `Dict[str, Any]`: The parent's serialized dictionary without the derived entries.
    """
    serialized = super().to_dict()

    # These values are derived from the other properties, so they are left out of the serialized form.
    # A key that is absent from the parent's dictionary is simply skipped.
    for derived_key in ("window", "mel_filters", "sample_size", "sample_stride", "n_fft", "n_freqs"):
        serialized.pop(derived_key, None)

    return serialized

0 comments on commit 52d0c2f

Please sign in to comment.