Skip to content

Commit

Permalink
feat: implement VAD on realtime transcription (#129)
Browse files Browse the repository at this point in the history
* feat(ios): initial work for simple VAD

* feat(ios): skip vad if isTranscribing

* feat(ios): add vadMs / vadThold / vadFreqThold options

* feat(android): implement vad on realtime transcription

* feat: use vad to check last transcription

* feat(example): do not use vad by default
  • Loading branch information
jhen0409 committed Sep 23, 2023
1 parent 61f01e7 commit 965409d
Show file tree
Hide file tree
Showing 8 changed files with 167 additions and 7 deletions.
33 changes: 33 additions & 0 deletions android/src/main/java/com/rnwhisper/WhisperContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,27 @@ private void saveWavFile(byte[] rawData, String audioOutputFile) throws IOExcept
}
}

/**
 * Runs the simple voice-activity check on the most recent audio of the current slice.
 *
 * VAD is skipped (and {@code true} is returned) when a transcription is already
 * running or when the {@code useVad} option is absent or false.
 *
 * @param options     realtime options; reads {@code useVad}, {@code vadMs},
 *                    {@code vadThold} and {@code vadFreqThold}
 * @param shortBuffer 16-bit PCM samples of the current slice
 * @param nSamples    number of samples already counted for the slice
 * @param n           number of samples just appended after {@code nSamples}
 *                    (0 when nothing new was appended)
 * @return {@code true} when VAD is skipped or the native check passes;
 *         {@code false} when fewer samples than the VAD window are available
 *         or the native check fails
 */
private boolean vad(ReadableMap options, short[] shortBuffer, int nSamples, int n) {
  if (isTranscribing || !options.hasKey("useVad") || !options.getBoolean("useVad")) {
    return true;
  }

  // Keep millisecond precision instead of truncating to whole seconds.
  // The native check compares the last 1000 ms against the whole window and
  // rejects windows that are not longer than that tail, so a window below
  // 2000 ms (the documented default) is clamped up rather than silently
  // never detecting speech.
  // NOTE(review): assumes SAMPLE_RATE equals the native WHISPER_SAMPLE_RATE - confirm.
  int vadMs = options.hasKey("vadMs") ? options.getInt("vadMs") : 2000;
  if (vadMs < 2000) {
    vadMs = 2000;
  }
  int sampleSize = (int) ((long) vadMs * SAMPLE_RATE / 1000);

  int available = nSamples + n;
  if (available <= sampleSize) {
    // Not enough audio collected yet to fill one VAD window.
    return false;
  }

  // Convert the trailing window to normalized floats for the native call.
  int start = available - sampleSize;
  float[] audioData = new float[sampleSize];
  for (int i = 0; i < sampleSize; i++) {
    audioData[i] = shortBuffer[i + start] / 32768.0f;
  }

  float vadThold = options.hasKey("vadThold") ? (float) options.getDouble("vadThold") : 0.6f;
  // Default matches the iOS implementation and the documented option default.
  float vadFreqThold = options.hasKey("vadFreqThold") ? (float) options.getDouble("vadFreqThold") : 100.0f;
  return vadSimple(audioData, sampleSize, vadThold, vadFreqThold);
}

public int startRealtimeTranscribe(int jobId, ReadableMap options) {
if (isCapturing || isTranscribing) {
return -100;
Expand Down Expand Up @@ -223,6 +244,12 @@ public void run() {
) {
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
} else if (!isTranscribing) {
short[] shortBuffer = shortBufferSlices.get(sliceIndex);
boolean isSpeech = vad(options, shortBuffer, nSamples, 0);
if (!isSpeech) {
emitTranscribeEvent("@RNWhisper_onRealtimeTranscribeEnd", Arguments.createMap());
break;
}
isTranscribing = true;
fullTranscribeSamples(options, true);
}
Expand All @@ -244,9 +271,14 @@ public void run() {
for (int i = 0; i < n; i++) {
shortBuffer[nSamples + i] = buffer[i];
}

boolean isSpeech = vad(options, shortBuffer, nSamples, n);

nSamples += n;
sliceNSamples.set(sliceIndex, nSamples);

if (!isSpeech) continue;

if (!isTranscribing && nSamples > SAMPLE_RATE / 2) {
isTranscribing = true;
fullHandler = new Thread(new Runnable() {
Expand Down Expand Up @@ -593,6 +625,7 @@ private static String cpuInfo() {
protected static native long initContext(String modelPath);
protected static native long initContextWithAsset(AssetManager assetManager, String modelPath);
protected static native long initContextWithInputStream(PushbackInputStream inputStream);
/**
 * Native simple voice-activity check (implemented in jni.cpp).
 * The native side copies {@code audio_data_len} floats from {@code audio_data} and
 * evaluates them with {@code rn_whisper_vad_simple}, using the last 1000 ms as the
 * comparison tail.
 *
 * @param audio_data     float PCM samples to analyse
 * @param audio_data_len number of samples to read from {@code audio_data}
 * @param vad_thold      energy ratio threshold forwarded to the native check
 * @param vad_freq_thold high-pass cutoff forwarded to the native check
 * @return the boolean result of the native check
 */
protected static native boolean vadSimple(float[] audio_data, int audio_data_len, float vad_thold, float vad_freq_thold);
protected static native int fullTranscribe(
int job_id,
long context,
Expand Down
22 changes: 22 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <sys/sysinfo.h>
#include <string>
#include <thread>
#include <vector>
#include "whisper.h"
#include "rn-whisper.h"
#include "ggml.h"
Expand Down Expand Up @@ -184,6 +185,27 @@ Java_com_rnwhisper_WhisperContext_initContextWithInputStream(
return reinterpret_cast<jlong>(context);
}

// JNI bridge for WhisperContext.vadSimple.
// Copies up to `audio_data_len` floats out of the Java array and runs
// rn_whisper_vad_simple on them with a fixed 1000 ms tail window.
// Returns JNI_FALSE when the array is null or empty, which matches what the
// native check returns when there are not enough samples.
JNIEXPORT jboolean JNICALL
Java_com_rnwhisper_WhisperContext_vadSimple(
    JNIEnv *env,
    jobject thiz,
    jfloatArray audio_data,
    jint audio_data_len,
    jfloat vad_thold,
    jfloat vad_freq_thold
) {
    UNUSED(thiz);

    if (audio_data == nullptr || audio_data_len <= 0) {
        return JNI_FALSE;
    }

    // Never read past the real end of the Java array, whatever length the caller claims.
    const jsize available = env->GetArrayLength(audio_data);
    const jsize count = audio_data_len < available ? audio_data_len : available;
    if (count <= 0) {
        return JNI_FALSE;
    }

    // Copy straight into the vector: no pinned/borrowed pointer to null-check or release.
    std::vector<float> samples(count);
    env->GetFloatArrayRegion(audio_data, 0, count, samples.data());

    const bool is_speech = rn_whisper_vad_simple(samples, WHISPER_SAMPLE_RATE, 1000, vad_thold, vad_freq_thold, false);
    return is_speech ? JNI_TRUE : JNI_FALSE;
}

struct progress_callback_context {
JNIEnv *env;
jobject progress_callback_instance;
Expand Down
51 changes: 51 additions & 0 deletions cpp/rn-whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,55 @@ void rn_whisper_abort_all_transcribe() {
}
}

// Filters `data` in place using a single-pole recurrence derived from
// `cutoff` and `sample_rate`. The first sample is left untouched.
// An empty vector is returned unchanged.
//
// NOTE(review): inside the loop data[i - 1] has already been overwritten with
// the previous output, not the previous input. Behaviour is kept as-is because
// the VAD thresholds are used against this output - confirm before changing.
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
    if (data.empty()) {
        // Reading data[0] below would be undefined behaviour.
        return;
    }

    const float rc = 1.0f / (2.0f * M_PI * cutoff);
    const float dt = 1.0f / sample_rate;
    const float alpha = dt / (rc + dt);

    float y = data[0];

    for (size_t i = 1; i < data.size(); i++) {
        y = alpha * (y + data[i] - data[i - 1]);
        data[i] = y;
    }
}

// Energy-based voice-activity check.
//
// Compares the mean absolute amplitude of the last `last_ms` milliseconds of
// `pcmf32` with the mean absolute amplitude of the whole buffer.
//
// - Returns false when the buffer is not longer than the `last_ms` tail.
// - When `freq_thold` is positive, `pcmf32` is high-pass filtered IN PLACE
//   first, so the caller's buffer is modified.
// - Returns false when the tail energy is greater than `vad_thold` times the
//   overall energy, true otherwise.
// - When `verbose` is set, the computed energies are printed to stderr.
bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
    const int total = pcmf32.size();
    const int tail  = (sample_rate * last_ms) / 1000;

    // not enough samples - assume no speech
    if (tail >= total) {
        return false;
    }

    if (freq_thold > 0.0f) {
        high_pass_filter(pcmf32, freq_thold, sample_rate);
    }

    // Index of the first sample that belongs to the tail window.
    const int tail_begin = total - tail;

    float sum_all  = 0.0f;
    float sum_tail = 0.0f;
    for (int idx = 0; idx < total; idx++) {
        const float magnitude = fabsf(pcmf32[idx]);
        sum_all += magnitude;
        if (idx >= tail_begin) {
            sum_tail += magnitude;
        }
    }

    const float energy_all  = sum_all / total;
    const float energy_last = sum_tail / tail;

    if (verbose) {
        fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
    }

    // Written as a negated '>' so non-finite energies compare exactly as before.
    return !(energy_last > vad_thold * energy_all);
}

}
3 changes: 2 additions & 1 deletion cpp/rn-whisper.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ void rn_whisper_remove_abort_map(int job_id);
void rn_whisper_abort_transcribe(int job_id);
bool rn_whisper_transcribe_is_aborted(int job_id);
void rn_whisper_abort_all_transcribe();
// Energy-based voice-activity check over `pcmf32`.
// Compares the mean absolute amplitude of the last `last_ms` milliseconds with
// that of the whole buffer; returns false when the buffer is not longer than
// that tail or when the tail energy exceeds `vad_thold` times the overall
// energy, true otherwise. When `freq_thold` is positive the buffer is
// high-pass filtered in place first. `verbose` prints the energies to stderr.
bool rn_whisper_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose);

#ifdef __cplusplus
}
#endif
#endif
12 changes: 6 additions & 6 deletions example/ios/Podfile.lock
Original file line number Diff line number Diff line change
Expand Up @@ -750,16 +750,16 @@ PODS:
- React-perflogger (= 0.71.11)
- RNFS (2.20.0):
- React-Core
- RNZipArchive (6.0.9):
- RNZipArchive (6.1.0):
- React-Core
- RNZipArchive/Core (= 6.0.9)
- RNZipArchive/Core (= 6.1.0)
- SSZipArchive (~> 2.2)
- RNZipArchive/Core (6.0.9):
- RNZipArchive/Core (6.1.0):
- React-Core
- SSZipArchive (~> 2.2)
- SocketRocket (0.6.0)
- SSZipArchive (2.4.3)
- whisper-rn (0.3.5):
- whisper-rn (0.3.6):
- RCT-Folly
- RCTRequired
- RCTTypeSafety
Expand Down Expand Up @@ -994,10 +994,10 @@ SPEC CHECKSUMS:
React-runtimeexecutor: 4817d63dbc9d658f8dc0ec56bd9b83ce531129f0
ReactCommon: 08723d2ed328c5cbcb0de168f231bc7bae7f8aa1
RNFS: 4ac0f0ea233904cb798630b3c077808c06931688
RNZipArchive: 68a0c6db4b1c103f846f1559622050df254a3ade
RNZipArchive: ef9451b849c45a29509bf44e65b788829ab07801
SocketRocket: fccef3f9c5cedea1353a9ef6ada904fde10d6608
SSZipArchive: fe6a26b2a54d5a0890f2567b5cc6de5caa600aef
whisper-rn: 6f293154b175fee138a994fa00d0f414fb1f44e9
whisper-rn: e80c0482f6a632faafd601f98f10da0255c1e1ec
Yoga: f7decafdc5e8c125e6fa0da38a687e35238420fa
YogaKit: f782866e155069a2cca2517aafea43200b01fd5a

Expand Down
2 changes: 2 additions & 0 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,8 @@ export default function App() {
realtimeAudioSec: 60,
// Slice audio into 25 (or < 30) sec chunks for better performance
realtimeAudioSliceSec: 25,
// Voice Activity Detection - Start transcribing when speech is detected
// useVad: true,
})
setStopTranscribe({ stop })
subscribe((evt) => {
Expand Down
33 changes: 33 additions & 0 deletions ios/RNWhisperContext.mm
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#import "RNWhisperContext.h"
#include <vector>

#define NUM_BYTES_PER_BUFFER 16 * 1024

Expand Down Expand Up @@ -77,6 +78,29 @@ - (void)freeBufferIfNeeded {
}
}

/// Runs the simple voice-activity check on the most recent audio of the current slice.
///
/// VAD is skipped (and true is returned) when a transcription is already running
/// or when the `useVad` option is missing or false.
///
/// @param state          Record state; reads `isTranscribing` and the `useVad`,
///                       `vadMs`, `vadThold` and `vadFreqThold` options.
/// @param audioBufferI16 16-bit PCM samples of the current slice.
/// @param nSamples       Number of samples already counted for the slice.
/// @param n              Number of samples just appended after `nSamples`
///                       (0 when nothing new was appended).
/// @return true when VAD is skipped or the check passes; false when fewer samples
///         than the VAD window are available or the check fails.
bool vad(RNWhisperContextRecordState *state, int16_t* audioBufferI16, int nSamples, int n)
{
    // Use boolValue: a bare non-nil test would treat an explicit `useVad: false` as enabled.
    if (state->isTranscribing || ![state->options[@"useVad"] boolValue]) {
        return true;
    }

    // Keep millisecond precision instead of truncating to whole seconds.
    // rn_whisper_vad_simple is called with a 1000 ms tail and rejects windows
    // that are not longer than that tail, so a window below 2000 ms (the
    // documented default) is clamped up rather than never detecting speech.
    id vadMsOption = state->options[@"vadMs"];
    int vadMs = vadMsOption != nil ? [vadMsOption intValue] : 2000;
    if (vadMs < 2000) {
        vadMs = 2000;
    }
    int sampleSize = (int)((long long)vadMs * WHISPER_SAMPLE_RATE / 1000);

    int available = nSamples + n;
    if (available <= sampleSize) {
        // Not enough audio collected yet to fill one VAD window.
        return false;
    }

    // Convert the trailing window to normalized floats.
    int start = available - sampleSize;
    std::vector<float> audioBufferF32Vec(sampleSize);
    for (int i = 0; i < sampleSize; i++) {
        audioBufferF32Vec[i] = (float)audioBufferI16[i + start] / 32768.0f;
    }

    id vadTholdOption = state->options[@"vadThold"];
    id vadFreqTholdOption = state->options[@"vadFreqThold"];
    float vadThold = vadTholdOption != nil ? [vadTholdOption floatValue] : 0.6f;
    float vadFreqThold = vadFreqTholdOption != nil ? [vadFreqTholdOption floatValue] : 100.0f;

    bool isSpeech = rn_whisper_vad_simple(audioBufferF32Vec, WHISPER_SAMPLE_RATE, 1000, vadThold, vadFreqThold, false);
    NSLog(@"[RNWhisper] VAD result: %d", isSpeech);
    return isSpeech;
}

void AudioInputCallback(void * inUserData,
AudioQueueRef inAQ,
AudioQueueBufferRef inBuffer,
Expand Down Expand Up @@ -117,6 +141,11 @@ void AudioInputCallback(void * inUserData,
!state->isTranscribing &&
nSamples != state->nSamplesTranscribing
) {
int16_t* audioBufferI16 = (int16_t*) [state->shortBufferSlices[state->sliceIndex] pointerValue];
if (!vad(state, audioBufferI16, nSamples, 0)) {
state->transcribeHandler(state->jobId, @"end", @{});
return;
}
state->isTranscribing = true;
dispatch_async([state->mSelf getDispatchQueue], ^{
[state->mSelf fullTranscribeSamples:state];
Expand All @@ -142,11 +171,15 @@ void AudioInputCallback(void * inUserData,
for (int i = 0; i < n; i++) {
audioBufferI16[nSamples + i] = ((short*)inBuffer->mAudioData)[i];
}

bool isSpeech = vad(state, audioBufferI16, nSamples, n);
nSamples += n;
state->sliceNSamples[state->sliceIndex] = [NSNumber numberWithInt:nSamples];

AudioQueueEnqueueBuffer(state->queue, inBuffer, 0, NULL);

if (!isSpeech) return;

if (!state->isTranscribing) {
state->isTranscribing = true;
dispatch_async([state->mSelf getDispatchQueue], ^{
Expand Down
18 changes: 18 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,24 @@ export type TranscribeRealtimeOptions = TranscribeOptions & {
* (Default: Undefined)
*/
audioOutputPath?: string
/**
* Use VAD (Voice Activity Detection) to start transcribing during recording only once the audio volume exceeds the threshold.
* The first VAD check is triggered after 2 seconds of recording.
* (Default: false)
*/
useVad?: boolean
/**
* Length (in ms) of the most recently collected audio that is used for VAD. (Default: 2000)
*/
vadMs?: number
/**
* VAD threshold. (Default: 0.6)
*/
vadThold?: number
/**
* Cutoff frequency of the high-pass filter applied to the audio before VAD. (Default: 100.0)
*/
vadFreqThold?: number
}

export type TranscribeRealtimeEvent = {
Expand Down

0 comments on commit 965409d

Please sign in to comment.