diff --git a/CMakeLists.txt b/CMakeLists.txt index dbac3a0..8d289be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,7 +25,7 @@ cmake_minimum_required(VERSION 3.3 FATAL_ERROR) project(kaldi-native-fbank CXX C) -set(KALDI_NATIVE_FBANK_VERSION "1.19.3") +set(KALDI_NATIVE_FBANK_VERSION "1.20.0") # Disable warning about # diff --git a/kaldi-native-fbank/csrc/whisper-feature.cc b/kaldi-native-fbank/csrc/whisper-feature.cc index e8ae5a2..1b69984 100644 --- a/kaldi-native-fbank/csrc/whisper-feature.cc +++ b/kaldi-native-fbank/csrc/whisper-feature.cc @@ -19,11 +19,11 @@ #include "kaldi-native-fbank/csrc/whisper-feature.h" #include -#include #include +#include -#include "kaldi-native-fbank/csrc/mel-computations.h" #include "kaldi-native-fbank/csrc/log.h" +#include "kaldi-native-fbank/csrc/mel-computations.h" #ifndef M_2PI #define M_2PI 6.283185307179586476925286766559005 @@ -31,6 +31,14 @@ namespace knf { +std::string WhisperFeatureOptions::ToString() const { + std::ostringstream os; + os << "WhisperFeatureOptions("; + os << "frame_opts=" << frame_opts.ToString() << ", "; + os << "dim=" << dim << ")"; + return os.str(); +} + static void dft(const std::vector &in, std::vector *out) { // this function is modified from // https://github.com/ggerganov/whisper.cpp/blob/master/whisper.cpp#L2353 @@ -116,23 +124,24 @@ static void fft(const std::vector &in, std::vector *out) { } WhisperFeatureComputer::WhisperFeatureComputer( - const FrameExtractionOptions & /*unused={}*/) { - frame_opts_.samp_freq = 16000; - frame_opts_.frame_shift_ms = 10; - frame_opts_.frame_length_ms = 25; - frame_opts_.dither = 0; - frame_opts_.preemph_coeff = 0; - frame_opts_.remove_dc_offset = false; - frame_opts_.window_type = "hann"; - frame_opts_.round_to_power_of_two = false; - frame_opts_.snip_edges = false; + const WhisperFeatureOptions &opts /*= {}*/) + : opts_(opts) { + opts_.frame_opts.samp_freq = 16000; + opts_.frame_opts.frame_shift_ms = 10; + opts_.frame_opts.frame_length_ms = 25; + opts_.frame_opts.dither = 0; + opts_.frame_opts.preemph_coeff = 0; + opts_.frame_opts.remove_dc_offset = false; + opts_.frame_opts.window_type = "hann"; + opts_.frame_opts.round_to_power_of_two = false; + opts_.frame_opts.snip_edges = false; MelBanksOptions mel_opts; - mel_opts.num_bins = 80; + mel_opts.num_bins = opts_.dim; mel_opts.low_freq = 0; mel_opts.is_librosa = true; - mel_banks_ = std::make_unique(mel_opts, frame_opts_, 1.0f); + mel_banks_ = std::make_unique(mel_opts, opts_.frame_opts, 1.0f); } void WhisperFeatureComputer::Compute(float /*signal_raw_log_energy*/, diff --git a/kaldi-native-fbank/csrc/whisper-feature.h b/kaldi-native-fbank/csrc/whisper-feature.h index 033bd7c..8ada4bd 100644 --- a/kaldi-native-fbank/csrc/whisper-feature.h +++ b/kaldi-native-fbank/csrc/whisper-feature.h @@ -19,23 +19,36 @@ #ifndef KALDI_NATIVE_FBANK_CSRC_WHISPER_FEATURE_H_ #define KALDI_NATIVE_FBANK_CSRC_WHISPER_FEATURE_H_ +#include #include #include -#include #include "kaldi-native-fbank/csrc/feature-window.h" #include "kaldi-native-fbank/csrc/mel-computations.h" namespace knf { +struct WhisperFeatureOptions { + WhisperFeatureOptions(const FrameExtractionOptions &frame_opts = {}, + int32_t dim = 80) + : frame_opts(frame_opts), dim(dim) {} + + FrameExtractionOptions frame_opts; + int32_t dim = 80; + + std::string ToString() const; +}; + class WhisperFeatureComputer { public: - explicit WhisperFeatureComputer( - const FrameExtractionOptions &unused_frame_opts_ = {}); + // note: opts.frame_opts is ignored and we reset it inside + explicit WhisperFeatureComputer(const WhisperFeatureOptions &opts = {}); - int32_t Dim() const { return 80; } + int32_t Dim() const { return opts_.dim; } - const FrameExtractionOptions &GetFrameOptions() const { return frame_opts_; } + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } void Compute(float /*signal_raw_log_energy*/, float /*vtln_warp*/, std::vector *signal_frame, float *feature); @@ -43,11 +56,11 @@ class WhisperFeatureComputer { // if true, compute log_energy_pre_window but after dithering and dc removal bool NeedRawLogEnergy() const { return false; } - using Options = FrameExtractionOptions; + using Options = WhisperFeatureOptions; private: std::unique_ptr mel_banks_; - FrameExtractionOptions frame_opts_; + WhisperFeatureOptions opts_; }; } // namespace knf diff --git a/kaldi-native-fbank/python/csrc/online-feature.cc b/kaldi-native-fbank/python/csrc/online-feature.cc index d0bb0ec..b650c13 100644 --- a/kaldi-native-fbank/python/csrc/online-feature.cc +++ b/kaldi-native-fbank/python/csrc/online-feature.cc @@ -26,6 +26,7 @@ #include "kaldi-native-fbank/csrc/feature-mfcc.h" #include "kaldi-native-fbank/csrc/online-feature.h" #include "kaldi-native-fbank/csrc/whisper-feature.h" +#include "kaldi-native-fbank/python/csrc/utils.h" namespace pybind11 { class gil_scoped_release; @@ -33,6 +34,27 @@ class gil_scoped_release; namespace knf { +static void PybindWhisperFeatureOptions(py::module &m) { // NOLINT + using PyClass = WhisperFeatureOptions; + py::class_(m, "WhisperFeatureOptions") + .def(py::init<>()) + .def_readwrite("frame_opts", &PyClass::frame_opts) + .def_readwrite("dim", &PyClass::dim) + .def("__str__", + [](const PyClass &self) -> std::string { return self.ToString(); }) + .def("as_dict", + [](const PyClass &self) -> py::dict { return AsDict(self); }) + .def_static("from_dict", + [](py::dict dict) -> PyClass { + return WhisperFeatureOptionsFromDict(dict); + }) + .def(py::pickle( + [](const PyClass &self) -> py::dict { return AsDict(self); }, + [](py::dict dict) -> PyClass { + return WhisperFeatureOptionsFromDict(dict); + })); +} + template void PybindOnlineFeatureTpl(py::module &m, // NOLINT const std::string &class_name, @@ -76,6 +98,9 @@ void PybindOnlineFeatureTpl(py::module &m, // NOLINT void PybindOnlineFeature(py::module &m) { // NOLINT PybindOnlineFeatureTpl(m, "OnlineFbank"); PybindOnlineFeatureTpl(m, "OnlineMfcc"); + + PybindWhisperFeatureOptions(m); + PybindOnlineFeatureTpl(m, "OnlineWhisperFbank"); } diff --git a/kaldi-native-fbank/python/csrc/utils.cc b/kaldi-native-fbank/python/csrc/utils.cc index d7e259b..3a5cc40 100644 --- a/kaldi-native-fbank/python/csrc/utils.cc +++ b/kaldi-native-fbank/python/csrc/utils.cc @@ -23,6 +23,7 @@ #include "kaldi-native-fbank/csrc/feature-fbank.h" #include "kaldi-native-fbank/csrc/feature-mfcc.h" #include "kaldi-native-fbank/csrc/feature-window.h" +#include "kaldi-native-fbank/csrc/whisper-feature.h" #define FROM_DICT(type, key) \ if (dict.contains(#key)) { \ @@ -82,6 +83,7 @@ MelBanksOptions MelBanksOptionsFromDict(py::dict dict) { return opts; } + py::dict AsDict(const MelBanksOptions &opts) { py::dict dict; @@ -170,6 +172,27 @@ py::dict AsDict(const MfccOptions &opts) { return dict; } +WhisperFeatureOptions WhisperFeatureOptionsFromDict(py::dict dict) { + WhisperFeatureOptions opts; + + if (dict.contains("frame_opts")) { + opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]); + } + + FROM_DICT(int_, dim); + + return opts; +} + +py::dict AsDict(const WhisperFeatureOptions &opts) { + py::dict dict; + + dict["frame_opts"] = AsDict(opts.frame_opts); + AS_DICT(dim); + + return dict; +} + #undef FROM_DICT #undef AS_DICT diff --git a/kaldi-native-fbank/python/csrc/utils.h b/kaldi-native-fbank/python/csrc/utils.h index c273c85..ab6bee1 100644 --- a/kaldi-native-fbank/python/csrc/utils.h +++ b/kaldi-native-fbank/python/csrc/utils.h @@ -23,6 +23,7 @@ #include "kaldi-native-fbank/csrc/feature-mfcc.h" #include "kaldi-native-fbank/csrc/feature-window.h" #include "kaldi-native-fbank/csrc/mel-computations.h" +#include "kaldi-native-fbank/csrc/whisper-feature.h" #include "kaldi-native-fbank/python/csrc/kaldi-native-fbank.h" /* @@ -51,6 +52,9 @@ py::dict AsDict(const FbankOptions &opts); MfccOptions MfccOptionsFromDict(py::dict dict); py::dict AsDict(const MfccOptions &opts); +WhisperFeatureOptions WhisperFeatureOptionsFromDict(py::dict dict); +py::dict AsDict(const WhisperFeatureOptions &opts); + } // namespace knf #endif // KALDI_NATIVE_FBANK_PYTHON_CSRC_UTILS_H_ diff --git a/kaldi-native-fbank/python/kaldi_native_fbank/__init__.py b/kaldi-native-fbank/python/kaldi_native_fbank/__init__.py index d0b92f0..d336426 100644 --- a/kaldi-native-fbank/python/kaldi_native_fbank/__init__.py +++ b/kaldi-native-fbank/python/kaldi_native_fbank/__init__.py @@ -8,4 +8,5 @@ OnlineMfcc, OnlineWhisperFbank, Rfft, + WhisperFeatureOptions, ) diff --git a/kaldi-native-fbank/python/tests/test_online_whisper_fbank.py b/kaldi-native-fbank/python/tests/test_online_whisper_fbank.py index b170b9d..5219f4a 100755 --- a/kaldi-native-fbank/python/tests/test_online_whisper_fbank.py +++ b/kaldi-native-fbank/python/tests/test_online_whisper_fbank.py @@ -9,7 +9,10 @@ def test(): - opts = knf.FrameExtractionOptions() + opts = knf.WhisperFeatureOptions() + + # Use 128 for whisper large v3 + opts.dim = 128 online_whisper_fbank = knf.OnlineWhisperFbank(opts) audio = torch.rand(100000) @@ -36,7 +39,7 @@ def test(): mel = mel.t().unsqueeze(0) print(mel.shape) - assert mel.shape == (1, 80, 3000), mel.shape + assert mel.shape == (1, opts.dim, 3000), mel.shape # Now you can input 'mel' to whisper.encoder model