Skip to content

Commit

Permalink
Fix bug in FFT
Browse files Browse the repository at this point in the history
The FFT routine does not work for odd N.
The solution is to add a naive DFT and use it when N is odd.
  • Loading branch information
ggerganov committed Oct 2, 2022
1 parent 6d654d1 commit 77d929f
Showing 1 changed file with 42 additions and 2 deletions.
44 changes: 42 additions & 2 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1909,8 +1909,31 @@ whisper_vocab::id whisper_sample_timestamp(
return probs_id[0].second;
}

// naive Discrete Fourier Transform, O(N^2) multiply-adds
// input is real-valued:     in[n], n = 0..N-1
// output is complex-valued: out[2*k + 0] = Re(X[k]), out[2*k + 1] = Im(X[k]), k = 0..N-1
// with X[k] = sum_n in[n]*exp(-2*pi*i*k*n/N)
// works for any N (including odd N and N == 0); out is resized to 2*N
void dft(const std::vector<float> & in, std::vector<float> & out) {
    const int N = static_cast<int>(in.size());

    out.resize(N*2);

    // exp(-2*pi*i*k*n/N) depends only on (k*n) mod N, so N twiddle factors are enough.
    // They are evaluated in double with the angle kept inside [0, 2*pi): storing the
    // un-reduced angle 2*pi*k*n/N in a float loses precision once k*n gets large,
    // and this also needs only N trig calls instead of N^2.
    std::vector<float> tw_cos(N);
    std::vector<float> tw_sin(N);

    for (int j = 0; j < N; j++) {
        const double angle = 2.0*M_PI*j/N;
        tw_cos[j] = static_cast<float>(cos(angle));
        tw_sin[j] = static_cast<float>(sin(angle));
    }

    for (int k = 0; k < N; k++) {
        float re = 0;
        float im = 0;

        // idx follows (k*n) mod N incrementally; since k < N and idx < N,
        // a single subtraction is enough to wrap it back into range
        int idx = 0;

        for (int n = 0; n < N; n++) {
            re += in[n]*tw_cos[idx];
            im -= in[n]*tw_sin[idx];

            idx += k;
            if (idx >= N) {
                idx -= N;
            }
        }

        out[k*2 + 0] = re;
        out[k*2 + 1] = im;
    }
}

// Cooley-Tukey FFT
// poor man's implmentation - use something better
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
void fft(const std::vector<float> & in, std::vector<float> & out) {
Expand All @@ -1924,6 +1947,11 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
return;
}

if (N%2 == 1) {
dft(in, out);
return;
}

std::vector<float> even;
std::vector<float> odd;

Expand Down Expand Up @@ -2014,9 +2042,20 @@ bool log_mel_spectrogram(
// FFT -> mag^2
fft(fft_in, fft_out);

for (int j = 0; j < n_fft; j++) {
for (int j = 0; j < fft_size; j++) {
fft_out[j] = (fft_out[2*j + 0]*fft_out[2*j + 0] + fft_out[2*j + 1]*fft_out[2*j + 1]);
}
for (int j = 1; j < fft_size/2; j++) {
//if (i == 0) {
// printf("%d: %f %f\n", j, fft_out[j], fft_out[fft_size - j]);
//}
fft_out[j] += fft_out[fft_size - j];
}
if (i == 0) {
//for (int j = 0; j < fft_size; j++) {
// printf("%d: %e\n", j, fft_out[j]);
//}
}

// mel spectrogram
for (int j = 0; j < mel.n_mel; j++) {
Expand Down Expand Up @@ -2048,6 +2087,7 @@ bool log_mel_spectrogram(
mmax = mel.data[i];
}
}
//printf("%s: max = %f\n", __func__, mmax);

mmax -= 8.0;

Expand Down

0 comments on commit 77d929f

Please sign in to comment.