-
Notifications
You must be signed in to change notification settings - Fork 34
/
mel-computations.h
144 lines (114 loc) · 4.87 KB
/
mel-computations.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// kaldifeat/csrc/mel-computations.h
//
// Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang)
//
// This file is copied/modified from kaldi/src/feat/mel-computations.h
#include <cmath>
#include <string>
#include "kaldifeat/csrc/feature-window.h"
#ifndef KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_
#define KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_
namespace kaldifeat {
struct MelBanksOptions {
int32_t num_bins = 25; // e.g. 25; number of triangular bins
float low_freq = 20; // e.g. 20; lower frequency cutoff
// an upper frequency cutoff; 0 -> no cutoff, negative
// ->added to the Nyquist frequency to get the cutoff.
float high_freq = 0;
float vtln_low = 100; // vtln lower cutoff of warping function.
// vtln upper cutoff of warping function: if negative, added
// to the Nyquist frequency to get the cutoff.
float vtln_high = -500;
bool debug_mel = false;
// htk_mode is a "hidden" config, it does not show up on command line.
// Enables more exact compatibility with HTK, for testing purposes. Affects
// mel-energy flooring and reproduces a bug in HTK.
bool htk_mode = false;
std::string ToString() const {
std::ostringstream os;
os << "MelBanksOptions(";
os << "num_bins=" << num_bins << ", ";
os << "low_freq=" << low_freq << ", ";
os << "high_freq=" << high_freq << ", ";
os << "vtln_low=" << vtln_low << ", ";
os << "vtln_high=" << vtln_high << ", ";
os << "debug_mel=" << (debug_mel ? "True" : "False") << ", ";
os << "htk_mode=" << (htk_mode ? "True" : "False") << ")";
return os.str();
}
};
std::ostream &operator<<(std::ostream &os, const MelBanksOptions &opts);
class MelBanks {
public:
static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
}
static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
}
static float VtlnWarpFreq(
float vtln_low_cutoff,
float vtln_high_cutoff, // discontinuities in warp func
float low_freq,
float high_freq, // upper+lower frequency cutoffs in
// the mel computation
float vtln_warp_factor, float freq);
static float VtlnWarpMelFreq(float vtln_low_cutoff, float vtln_high_cutoff,
float low_freq, float high_freq,
float vtln_warp_factor, float mel_freq);
MelBanks(const MelBanksOptions &opts,
const FrameExtractionOptions &frame_opts, float vtln_warp_factor,
torch::Device device);
// Initialize with a 2-d weights matrix
//
// Note: This constructor is for Whisper. It does not initialize
// center_freqs_.
//
// @param weights Pointer to the start address of the matrix
// @param num_rows It equals to number of mel bins
// @param num_cols It equals to (number of fft bins)/2+1
MelBanks(const float *weights, int32_t num_rows, int32_t num_cols,
torch::Device device);
// CAUTION: we save a transposed version of bins_mat_, so return size(1) here
int32_t NumBins() const { return static_cast<int32_t>(bins_mat_.size(1)); }
// returns vector of central freq of each bin; needed by plp code.
const torch::Tensor &GetCenterFreqs() const { return center_freqs_; }
torch::Tensor Compute(const torch::Tensor &spectrum) const;
// for debug only
const torch::Tensor &GetBinsMat() const { return bins_mat_; }
private:
// A 2-D matrix. Its shape is NOT [num_bins, num_fft_bins]
// Its shape is [num_fft_bins, num_bins] for non-whisper.
// For whisper, its shape is [num_fft_bins/2+1, num_bins]
torch::Tensor bins_mat_;
// center frequencies of bins, numbered from 0 ... num_bins-1.
// Needed by GetCenterFreqs().
torch::Tensor center_freqs_; // It's always on CPU
bool debug_;
bool htk_mode_;
};
// Compute liftering coefficients (scaling on cepstral coeffs)
// coeffs are numbered slightly differently from HTK: the zeroth
// index is C0, which is not affected.
//
// coeffs is a 1-D float tensor
void ComputeLifterCoeffs(float Q, torch::Tensor *coeffs);
void GetEqualLoudnessVector(const MelBanks &mel_banks, torch::Tensor *ans);
/* Compute LP coefficients from autocorrelation coefficients.
*
* @param [in] autocorr_in A 2-D tensor. Each row is a frame. Its number of
* columns is lpc_order + 1
* @param [out] lpc_coeffs A 2-D tensor. On return, it has as many rows as the
* input tensor. Its number of columns is lpc_order.
*
* @return Returns log energy of residual in a 1-D tensor. It has as many
* elements as the number of rows in `autocorr_in`.
*/
torch::Tensor ComputeLpc(const torch::Tensor &autocorr_in,
torch::Tensor *lpc_coeffs);
/*
* @param [in] lpc It is the output argument `lpc_coeffs` in ComputeLpc().
*/
torch::Tensor Lpc2Cepstrum(const torch::Tensor &lpc);
} // namespace kaldifeat
#endif // KALDIFEAT_CSRC_MEL_COMPUTATIONS_H_