Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support computing features for whisper #82

Merged
merged 3 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ project(kaldifeat)
# remember to change the version in
# scripts/conda/kaldifeat/meta.yaml
# scripts/conda-cpu/kaldifeat/meta.yaml
set(kaldifeat_VERSION "1.25.1")
set(kaldifeat_VERSION "1.25.2")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
Expand Down
1 change: 1 addition & 0 deletions kaldifeat/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ set(kaldifeat_srcs
matrix-functions.cc
mel-computations.cc
online-feature.cc
whisper-fbank.cc
)

add_library(kaldifeat_core ${kaldifeat_srcs})
Expand Down
1 change: 1 addition & 0 deletions kaldifeat/csrc/CPPLINT.cfg
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
exclude_files=whisper-mel-bank.h
2 changes: 1 addition & 1 deletion kaldifeat/csrc/feature-fbank.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ torch::Tensor FbankComputer::Compute(torch::Tensor signal_raw_log_energy,
// note spectrum is in magnitude, not power, because of `abs()`
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
// signal_frame shape: [x, 512]
// spectrum shape [x, 257
// spectrum shape [x, 257]
torch::Tensor spectrum = torch::fft::rfft(signal_frame).abs();
#else
// signal_frame shape [x, 512]
Expand Down
9 changes: 9 additions & 0 deletions kaldifeat/csrc/feature-window.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
float *window_data = window.data_ptr<float>();

double a = M_2PI / (frame_length - 1);

if (opts.window_type == "hann") {
// see https://pytorch.org/docs/stable/generated/torch.hann_window.html
// We assume periodic is true
a = M_2PI / frame_length;
}

for (int32_t i = 0; i < frame_length; i++) {
double i_fl = static_cast<double>(i);
if (opts.window_type == "hanning") {
Expand All @@ -39,6 +46,8 @@ FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts,
window_data[i] = sin(0.5 * a * i_fl);
} else if (opts.window_type == "hamming") {
window_data[i] = 0.54 - 0.46 * cos(a * i_fl);
} else if (opts.window_type == "hann") {
window_data[i] = 0.50 - 0.50 * cos(a * i_fl);
} else if (opts.window_type ==
"povey") { // like hamming but goes to zero at edges.
window_data[i] = pow(0.5 - 0.5 * cos(a * i_fl), 0.85);
Expand Down
39 changes: 39 additions & 0 deletions kaldifeat/csrc/generate-whisper-melbank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#!/usr/bin/env python3

# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)

import librosa
import numpy as np


def main():
m = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)
assert m.shape == (80, 201)
s = "// Auto-generated. Do NOT edit!\n\n"
s += "// Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)\n\n"
s += "\n"
s += "#ifndef KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"
s += "#define KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"
s += "namespace kaldifeat {\n\n"
s += f"constexpr int32_t kWhisperMelRows = {m.shape[0]};\n"
s += f"constexpr int32_t kWhisperMelCols = {m.shape[1]};\n"
s += "\n"
s += "constexpr float kWhisperMelArray[] = {\n"
sep = ""
for i, f in enumerate(m.reshape(-1).tolist()):
s += f"{sep}{f:.8f}"
sep = ", "
if i and i % 7 == 0:
s += ",\n"
sep = ""

s += "};\n\n"
s += "} // namespace kaldifeat\n\n"
s += "#endif // KALDIFEAT_CSRC_WHISPER_MEL_BANK_H_\n"

with open("whisper-mel-bank.h", "w") as f:
f.write(s)


if __name__ == "__main__":
main()
9 changes: 9 additions & 0 deletions kaldifeat/csrc/mel-computations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,15 @@ MelBanks::MelBanks(const MelBanksOptions &opts,
}
}

MelBanks::MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
torch::Device device)
: debug_(false), htk_mode_(false) {
bins_mat_ = torch::from_blob(const_cast<float *>(weights),
{num_rows, num_cols}, torch::kFloat)
.t()
.to(device);
}

torch::Tensor MelBanks::Compute(const torch::Tensor &spectrum) const {
return torch::mm(spectrum, bins_mat_);
}
Expand Down
14 changes: 13 additions & 1 deletion kaldifeat/csrc/mel-computations.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,17 @@ class MelBanks {
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
torch::Device device);

// Initialize with a 2-d weights matrix
//
// Note: This constructor is for Whisper. It does not initialize
// center_freqs_.
//
// @param weights Pointer to the start address of the matrix
// @param num_rows It equals to number of mel bins
// @param num_cols It equals to (number of fft bins)/2+1
MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
torch::Device device);

// CAUTION: we save a transposed version of bins_mat_, so return size(1) here
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }

Expand All @@ -89,7 +100,8 @@ class MelBanks {

private:
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
// Its shape is [num_fft_bins, num_bins].
// Its shape is [num_fft_bins, num_bins] for non-whisper.
// For whisper, its shape is [num_fft_bins/2+1, num_bins]
torch::Tensor bins_mat_;

// center frequencies of bins, numbered from 0 ... num_bins-1.
Expand Down
78 changes: 78 additions & 0 deletions kaldifeat/csrc/whisper-fbank.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/**
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kaldifeat/csrc/whisper-fbank.h"

#include <cmath>
#include <vector>

#include "kaldifeat/csrc/mel-computations.h"
#include "kaldifeat/csrc/whisper-mel-bank.h"

#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif

namespace kaldifeat {

WhisperFbankComputer::WhisperFbankComputer(const WhisperFbankOptions &opts)
: opts_(opts),
mel_banks_(kWhisperMelArray, kWhisperMelRows, kWhisperMelCols,
opts.device) {
opts_.frame_opts.samp_freq = 16000;
opts_.frame_opts.frame_shift_ms = 10;
opts_.frame_opts.frame_length_ms = 25;
opts_.frame_opts.dither = 0;
opts_.frame_opts.preemph_coeff = 0;
opts_.frame_opts.remove_dc_offset = false;
opts_.frame_opts.window_type = "hann";
opts_.frame_opts.round_to_power_of_two = false;
opts_.frame_opts.snip_edges = false;
}

torch::Tensor WhisperFbankComputer::Compute(
torch::Tensor /*signal_raw_log_energy*/, float /*vtln_warp*/,
const torch::Tensor &signal_frame) {
KALDIFEAT_ASSERT(signal_frame.dim() == 2);
KALDIFEAT_ASSERT(signal_frame.size(1) == opts_.frame_opts.PaddedWindowSize());

// note spectrum is in magnitude, not power, because of `abs()`
#if defined(KALDIFEAT_HAS_FFT_NAMESPACE)
// signal_frame shape: [x, 512]
// power shape [x, 257]
torch::Tensor power = torch::fft::rfft(signal_frame).abs().pow(2);
#else
// signal_frame shape [x, 512]
// real_imag shape [x, 257, 2],
// where [..., 0] is the real part
// [..., 1] is the imaginary part
torch::Tensor real_imag = torch::rfft(signal_frame, 1);
torch::Tensor real = real_imag.index({"...", 0});
torch::Tensor imag = real_imag.index({"...", 1});
torch::Tensor power = (real.square() + imag.square());
#endif

torch::Tensor mel_energies = mel_banks_.Compute(power);
torch::Tensor log_spec = torch::clamp_min(mel_energies, 1e-10).log10();
log_spec = torch::maximum(log_spec, log_spec.max() - 8.0);
torch::Tensor mel = (log_spec + 4.0) / 4.0;

return mel;
}

} // namespace kaldifeat
74 changes: 74 additions & 0 deletions kaldifeat/csrc/whisper-fbank.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef KALDIFEAT_CSRC_WHISPER_FBANK_H_
#define KALDIFEAT_CSRC_WHISPER_FBANK_H_

#include <string>
#include <vector>

#include "kaldifeat/csrc/feature-common.h"
#include "kaldifeat/csrc/feature-window.h"
#include "kaldifeat/csrc/mel-computations.h"

namespace kaldifeat {

struct WhisperFbankOptions {
FrameExtractionOptions frame_opts;

torch::Device device{"cpu"};
std::string ToString() const {
std::ostringstream os;
os << "WhisperFbankOptions(";
os << "frame_opts=" << frame_opts.ToString() << ", ";
os << "device=\"" << device << "\")";
return os.str();
}
};

class WhisperFbankComputer {
public:
// note: Only frame_opts.device is used. All other fields from frame_opts
// are ignored
explicit WhisperFbankComputer(const WhisperFbankOptions &opts = {});

int32_t Dim() const { return 80; }

const FrameExtractionOptions &GetFrameOptions() const {
return opts_.frame_opts;
}

const WhisperFbankOptions &GetOptions() const { return opts_; }

torch::Tensor Compute(torch::Tensor /*signal_raw_log_energy*/,
float /*vtln_warp*/, const torch::Tensor &signal_frame);

// if true, compute log_energy_pre_window but after dithering and dc removal
bool NeedRawLogEnergy() const { return false; }
using Options = WhisperFbankOptions;

private:
WhisperFbankOptions opts_;
MelBanks mel_banks_;
};

using WhisperFbank = OfflineFeatureTpl<WhisperFbankComputer>;

} // namespace kaldifeat

#endif // KALDIFEAT_CSRC_WHISPER_FBANK_H_
Loading
Loading