Skip to content

Commit

Permalink
support whisper large v3 (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jun 25, 2024
1 parent a97263c commit fdc395d
Show file tree
Hide file tree
Showing 8 changed files with 102 additions and 24 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ cmake_minimum_required(VERSION 3.3 FATAL_ERROR)

project(kaldi-native-fbank CXX C)

set(KALDI_NATIVE_FBANK_VERSION "1.19.3")
set(KALDI_NATIVE_FBANK_VERSION "1.20.0")

# Disable warning about
#
Expand Down
37 changes: 23 additions & 14 deletions kaldi-native-fbank/csrc/whisper-feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,26 @@
#include "kaldi-native-fbank/csrc/whisper-feature.h"

#include <cmath>
#include <vector>
#include <string>
#include <vector>

#include "kaldi-native-fbank/csrc/mel-computations.h"
#include "kaldi-native-fbank/csrc/log.h"
#include "kaldi-native-fbank/csrc/mel-computations.h"

#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif

namespace knf {

std::string WhisperFeatureOptions::ToString() const {
std::ostringstream os;
os << "WhisperFeatureOptions(";
os << "frame_opts=" << frame_opts.ToString() << ", ";
os << "dim=" << dim << ")";
return os.str();
}

static void dft(const std::vector<float> &in, std::vector<float> *out) {
// this function is modified from
// https://github.com/ggerganov/whisper.cpp/blob/master/whisper.cpp#L2353
Expand Down Expand Up @@ -116,23 +124,24 @@ static void fft(const std::vector<float> &in, std::vector<float> *out) {
}

WhisperFeatureComputer::WhisperFeatureComputer(
const FrameExtractionOptions & /*unused={}*/) {
frame_opts_.samp_freq = 16000;
frame_opts_.frame_shift_ms = 10;
frame_opts_.frame_length_ms = 25;
frame_opts_.dither = 0;
frame_opts_.preemph_coeff = 0;
frame_opts_.remove_dc_offset = false;
frame_opts_.window_type = "hann";
frame_opts_.round_to_power_of_two = false;
frame_opts_.snip_edges = false;
const WhisperFeatureOptions &opts /*= {}*/)
: opts_(opts) {
opts_.frame_opts.samp_freq = 16000;
opts_.frame_opts.frame_shift_ms = 10;
opts_.frame_opts.frame_length_ms = 25;
opts_.frame_opts.dither = 0;
opts_.frame_opts.preemph_coeff = 0;
opts_.frame_opts.remove_dc_offset = false;
opts_.frame_opts.window_type = "hann";
opts_.frame_opts.round_to_power_of_two = false;
opts_.frame_opts.snip_edges = false;

MelBanksOptions mel_opts;
mel_opts.num_bins = 80;
mel_opts.num_bins = opts_.dim;
mel_opts.low_freq = 0;
mel_opts.is_librosa = true;

mel_banks_ = std::make_unique<MelBanks>(mel_opts, frame_opts_, 1.0f);
mel_banks_ = std::make_unique<MelBanks>(mel_opts, opts_.frame_opts, 1.0f);
}

void WhisperFeatureComputer::Compute(float /*signal_raw_log_energy*/,
Expand Down
27 changes: 20 additions & 7 deletions kaldi-native-fbank/csrc/whisper-feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,48 @@
#ifndef KALDI_NATIVE_FBANK_CSRC_WHISPER_FEATURE_H_
#define KALDI_NATIVE_FBANK_CSRC_WHISPER_FEATURE_H_

#include <cstdint>
#include <memory>
#include <vector>
#include <cstdint>

#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/mel-computations.h"

namespace knf {

struct WhisperFeatureOptions {
WhisperFeatureOptions(const FrameExtractionOptions &frame_opts = {},
int32_t dim = 80)
: frame_opts(frame_opts), dim(dim) {}

FrameExtractionOptions frame_opts;
int32_t dim = 80;

std::string ToString() const;
};

class WhisperFeatureComputer {
public:
explicit WhisperFeatureComputer(
const FrameExtractionOptions &unused_frame_opts_ = {});
// note: opts.frame_opts is ignored and we reset it inside
explicit WhisperFeatureComputer(const WhisperFeatureOptions &opts = {});

int32_t Dim() const { return 80; }
int32_t Dim() const { return opts_.dim; }

const FrameExtractionOptions &GetFrameOptions() const { return frame_opts_; }
const FrameExtractionOptions &GetFrameOptions() const {
return opts_.frame_opts;
}

void Compute(float /*signal_raw_log_energy*/, float /*vtln_warp*/,
std::vector<float> *signal_frame, float *feature);

// if true, compute log_energy_pre_window but after dithering and dc removal
bool NeedRawLogEnergy() const { return false; }

using Options = FrameExtractionOptions;
using Options = WhisperFeatureOptions;

private:
std::unique_ptr<MelBanks> mel_banks_;
FrameExtractionOptions frame_opts_;
WhisperFeatureOptions opts_;
};

} // namespace knf
Expand Down
25 changes: 25 additions & 0 deletions kaldi-native-fbank/python/csrc/online-feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,35 @@
#include "kaldi-native-fbank/csrc/feature-mfcc.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "kaldi-native-fbank/csrc/whisper-feature.h"
#include "kaldi-native-fbank/python/csrc/utils.h"

namespace pybind11 {
class gil_scoped_release;
} // namespace pybind11

namespace knf {

static void PybindWhisperFeatureOptions(py::module &m) { // NOLINT
using PyClass = WhisperFeatureOptions;
py::class_<PyClass>(m, "WhisperFeatureOptions")
.def(py::init<>())
.def_readwrite("frame_opts", &PyClass::frame_opts)
.def_readwrite("dim", &PyClass::dim)
.def("__str__",
[](const PyClass &self) -> std::string { return self.ToString(); })
.def("as_dict",
[](const PyClass &self) -> py::dict { return AsDict(self); })
.def_static("from_dict",
[](py::dict dict) -> PyClass {
return WhisperFeatureOptionsFromDict(dict);
})
.def(py::pickle(
[](const PyClass &self) -> py::dict { return AsDict(self); },
[](py::dict dict) -> PyClass {
return WhisperFeatureOptionsFromDict(dict);
}));
}

template <typename C>
void PybindOnlineFeatureTpl(py::module &m, // NOLINT
const std::string &class_name,
Expand Down Expand Up @@ -76,6 +98,9 @@ void PybindOnlineFeatureTpl(py::module &m, // NOLINT
void PybindOnlineFeature(py::module &m) { // NOLINT
PybindOnlineFeatureTpl<FbankComputer>(m, "OnlineFbank");
PybindOnlineFeatureTpl<MfccComputer>(m, "OnlineMfcc");

PybindWhisperFeatureOptions(m);

PybindOnlineFeatureTpl<WhisperFeatureComputer>(m, "OnlineWhisperFbank");
}

Expand Down
23 changes: 23 additions & 0 deletions kaldi-native-fbank/python/csrc/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/feature-mfcc.h"
#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/whisper-feature.h"

#define FROM_DICT(type, key) \
if (dict.contains(#key)) { \
Expand Down Expand Up @@ -82,6 +83,7 @@ MelBanksOptions MelBanksOptionsFromDict(py::dict dict) {

return opts;
}

py::dict AsDict(const MelBanksOptions &opts) {
py::dict dict;

Expand Down Expand Up @@ -170,6 +172,27 @@ py::dict AsDict(const MfccOptions &opts) {
return dict;
}

WhisperFeatureOptions WhisperFeatureOptionsFromDict(py::dict dict) {
WhisperFeatureOptions opts;

if (dict.contains("frame_opts")) {
opts.frame_opts = FrameExtractionOptionsFromDict(dict["frame_opts"]);
}

FROM_DICT(int_, dim);

return opts;
}

py::dict AsDict(const WhisperFeatureOptions &opts) {
py::dict dict;

dict["frame_opts"] = AsDict(opts.frame_opts);
AS_DICT(dim);

return dict;
}

#undef FROM_DICT
#undef AS_DICT

Expand Down
4 changes: 4 additions & 0 deletions kaldi-native-fbank/python/csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "kaldi-native-fbank/csrc/feature-mfcc.h"
#include "kaldi-native-fbank/csrc/feature-window.h"
#include "kaldi-native-fbank/csrc/mel-computations.h"
#include "kaldi-native-fbank/csrc/whisper-feature.h"
#include "kaldi-native-fbank/python/csrc/kaldi-native-fbank.h"

/*
Expand Down Expand Up @@ -51,6 +52,9 @@ py::dict AsDict(const FbankOptions &opts);
MfccOptions MfccOptionsFromDict(py::dict dict);
py::dict AsDict(const MfccOptions &opts);

WhisperFeatureOptions WhisperFeatureOptionsFromDict(py::dict dict);
py::dict AsDict(const WhisperFeatureOptions &opts);

} // namespace knf

#endif // KALDI_NATIVE_FBANK_PYTHON_CSRC_UTILS_H_
1 change: 1 addition & 0 deletions kaldi-native-fbank/python/kaldi_native_fbank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@
OnlineMfcc,
OnlineWhisperFbank,
Rfft,
WhisperFeatureOptions,
)
7 changes: 5 additions & 2 deletions kaldi-native-fbank/python/tests/test_online_whisper_fbank.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@


def test():
opts = knf.FrameExtractionOptions()
opts = knf.WhisperFeatureOptions()

# Use 128 for whisper large v3
opts.dim = 128
online_whisper_fbank = knf.OnlineWhisperFbank(opts)

audio = torch.rand(100000)
Expand All @@ -36,7 +39,7 @@ def test():
mel = mel.t().unsqueeze(0)
print(mel.shape)

assert mel.shape == (1, 80, 3000), mel.shape
assert mel.shape == (1, opts.dim, 3000), mel.shape
# Now you can input 'mel' to whisper.encoder model


Expand Down

0 comments on commit fdc395d

Please sign in to comment.