draft: Add a different struct with a chunk_size value #11

Closed · wants to merge 1 commit
123 changes: 123 additions & 0 deletions src/dynamic_vad.rs
@@ -0,0 +1,123 @@
use ort::{GraphOptimizationLevel, Session};

use crate::{error::Error, Sample};

/// A voice activity detector session.
#[derive(Debug)]
pub struct DynamicVoiceActivityDetector {
    chunk_size: usize,
    sample_rate: i64,
    session: ort::Session,
    h: ndarray::Array3<f32>,
    c: ndarray::Array3<f32>,
}

/// The silero ONNX model as bytes.
const MODEL: &[u8] = include_bytes!("silero_vad.onnx");

impl DynamicVoiceActivityDetector {
    /// Creates a new [DynamicVoiceActivityDetector].
    pub fn try_with_sample_rate(
        chunk_size: impl Into<usize>,
        sample_rate: impl Into<i64>,
    ) -> Result<Self, Error> {
        let chunk_size = chunk_size.into();
        let sample_rate: i64 = sample_rate.into();
        if (sample_rate as f32) / (chunk_size as f32) > 31.25 {
            return Err(Error::VadConfigError {
                sample_rate,
                chunk_size,
            });
        }

        let session = Session::builder()
            .unwrap()
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .unwrap()
            .with_intra_threads(1)
            .unwrap()
            .with_inter_threads(1)
            .unwrap()
            .commit_from_memory(MODEL)
            .unwrap();

        Ok(Self::with_session(chunk_size, session, sample_rate))
    }

    /// Creates a new [DynamicVoiceActivityDetector] using the provided ONNX runtime session.
    ///
    /// Use this if the default ONNX session configuration is not to your liking.
    pub fn with_session(
        chunk_size: impl Into<usize>,
        session: Session,
        sample_rate: impl Into<i64>,
    ) -> Self {
        Self {
            chunk_size: chunk_size.into(),
            session,
            sample_rate: sample_rate.into(),
            h: ndarray::Array3::<f32>::zeros((2, 1, 64)),
            c: ndarray::Array3::<f32>::zeros((2, 1, 64)),
        }
    }

    /// Resets the state of the voice activity detector session.
    pub fn reset(&mut self) {
        self.h.fill(0f32);
        self.c.fill(0f32);
    }

    /// Predicts the existence of speech in a single iterable of audio.
    ///
    /// The samples iterator will be padded if it is too short, or truncated if it is
    /// too long.
    pub fn predict<S, I>(&mut self, samples: I) -> f32
    where
        S: Sample,
        I: IntoIterator<Item = S>,
    {
        let mut input = ndarray::Array2::<f32>::zeros((1, self.chunk_size));
        for (i, sample) in samples.into_iter().take(self.chunk_size).enumerate() {
            input[[0, i]] = sample.to_f32();
        }

        let sample_rate = ndarray::arr1::<i64>(&[self.sample_rate]);

        let inputs = ort::inputs![
            "input" => input.view(),
            "sr" => sample_rate.view(),
            "h" => self.h.view(),
            "c" => self.c.view(),
        ]
        .unwrap();

        let outputs = self.session.run(inputs).unwrap();

        // Carry the recurrent state (h and c) forward to the next chunk.
        let hn = outputs
            .get("hn")
            .unwrap()
            .try_extract_tensor::<f32>()
            .unwrap();
        let cn = outputs
            .get("cn")
            .unwrap()
            .try_extract_tensor::<f32>()
            .unwrap();

        self.h.assign(&hn.view());
        self.c.assign(&cn.view());

        // Get the probability of speech.
        let output = outputs
            .get("output")
            .unwrap()
            .try_extract_tensor::<f32>()
            .unwrap();
        let probability = output.view()[[0, 0]];

        probability
    }

    // A predict_array equivalent is omitted: with a runtime chunk_size, the output length is no longer known at compile time.
}
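A quick usage sketch of the new struct, not part of the diff: it assumes the DynamicVoiceActivityDetector re-export added to src/lib.rs below, and the 8 kHz sample rate with 256-sample chunks is only an illustrative pair that satisfies the sample_rate / chunk_size <= 31.25 check.

// Usage sketch only; assumes the `voice_activity_detector` crate with the
// re-export added in src/lib.rs below. Values are illustrative.
use voice_activity_detector::DynamicVoiceActivityDetector;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // chunk_size is a runtime value here instead of a const generic parameter.
    // 8000 Hz / 256 samples = 31.25, which is exactly the allowed limit.
    let mut vad = DynamicVoiceActivityDetector::try_with_sample_rate(256usize, 8000i64)?;

    // predict() zero-pads a short chunk and truncates a long one, so passing
    // exactly chunk_size samples is the intended use.
    let chunk = vec![0i16; 256];
    let probability = vad.predict(chunk);
    println!("speech probability: {probability}");

    // reset() clears the recurrent state between unrelated audio streams.
    vad.reset();
    Ok(())
}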
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -1,6 +1,7 @@
 #![warn(missing_docs)]
 #![doc = include_str!("../README.md")]
 
+mod dynamic_vad;
 mod error;
 mod iterator;
 mod label;
@@ -10,6 +11,7 @@ mod sample;
 mod stream;
 mod vad;
 
+pub use dynamic_vad::DynamicVoiceActivityDetector;
 pub use error::Error;
 pub use iterator::{IteratorExt, LabelIterator, PredictIterator};
 pub use label::LabeledAudio;
48 changes: 48 additions & 0 deletions tests/file_dynamic_vad.rs
@@ -0,0 +1,48 @@
use std::error::Error;

use itertools::Itertools;
use voice_activity_detector::DynamicVoiceActivityDetector;

#[test]
fn wave_file_dynamic_detector() -> Result<(), Box<dyn Error>> {
    std::fs::create_dir_all("tests/.outputs")?;

    let mut reader = hound::WavReader::open("tests/samples/sample.wav")?;
    let spec = reader.spec();

    let mut speech = hound::WavWriter::create("tests/.outputs/dynamic.iter.speech.wav", spec)?;
    let mut nonspeech =
        hound::WavWriter::create("tests/.outputs/dynamic.iter.nonspeech.wav", spec)?;

    let mut vad = DynamicVoiceActivityDetector::try_with_sample_rate(256usize, spec.sample_rate)?;

    let samples = reader
        .samples::<i16>()
        .map(|s| s.unwrap())
        .collect::<Vec<_>>();
    let chunks = samples
        .iter()
        .chunks(256)
        .into_iter()
        .map(|chunk| chunk.collect::<Vec<_>>())
        .collect::<Vec<_>>();

    for chunk in chunks {
        let ready_chunk = chunk.iter().map(|s| **s).collect::<Vec<_>>();
        let probability = vad.predict(ready_chunk);
        if probability > 0.5 {
            for sample in chunk {
                speech.write_sample(*sample)?;
            }
        } else {
            for sample in chunk {
                nonspeech.write_sample(*sample)?;
            }
        }
    }

    speech.finalize()?;
    nonspeech.finalize()?;

    Ok(())
}
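Not part of the diff, but a companion test along these lines could pin down the constructor's ratio check; it assumes the DynamicVoiceActivityDetector re-export from src/lib.rs and uses illustrative sample rate / chunk size pairs.

use voice_activity_detector::DynamicVoiceActivityDetector;

#[test]
fn dynamic_detector_rejects_too_small_chunks() {
    // 16000 Hz / 256 samples = 62.5 > 31.25, so the config check must fail.
    assert!(DynamicVoiceActivityDetector::try_with_sample_rate(256usize, 16000i64).is_err());

    // 16000 Hz / 512 samples = 31.25 sits exactly on the allowed limit.
    assert!(DynamicVoiceActivityDetector::try_with_sample_rate(512usize, 16000i64).is_ok());
}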
46 changes: 25 additions & 21 deletions tests/file_label_stream.rs
@@ -1,37 +1,41 @@
+#[cfg(feature = "async")]
 use tokio_stream::StreamExt;
+#[cfg(feature = "async")]
 use voice_activity_detector::{StreamExt as _, VoiceActivityDetector};
 
 #[tokio::test]
 async fn wave_file_label_iterator() -> Result<(), Box<dyn std::error::Error>> {
-    std::fs::create_dir_all("tests/.outputs")?;
+    #[cfg(feature = "async")]
+    {
+        std::fs::create_dir_all("tests/.outputs")?;
 
-    let mut reader = hound::WavReader::open("tests/samples/sample.wav")?;
-    let spec = reader.spec();
+        let mut reader = hound::WavReader::open("tests/samples/sample.wav")?;
+        let spec = reader.spec();
 
-    let mut speech = hound::WavWriter::create("tests/.outputs/label.stream.speech.wav", spec)?;
-    let mut nonspeech =
-        hound::WavWriter::create("tests/.outputs/label.stream.nonspeech.wav", spec)?;
+        let mut speech = hound::WavWriter::create("tests/.outputs/label.stream.speech.wav", spec)?;
+        let mut nonspeech =
+            hound::WavWriter::create("tests/.outputs/label.stream.nonspeech.wav", spec)?;
 
-    let vad = VoiceActivityDetector::<256>::try_with_sample_rate(spec.sample_rate)?;
+        let vad = VoiceActivityDetector::<256>::try_with_sample_rate(spec.sample_rate)?;
 
-    let chunks = reader.samples::<i16>().map_while(Result::ok);
+        let chunks = reader.samples::<i16>().map_while(Result::ok);
 
-    let mut chunks = tokio_stream::iter(chunks).label(vad, 0.5, 10);
+        let mut chunks = tokio_stream::iter(chunks).label(vad, 0.5, 10);
 
-    while let Some(chunk) = chunks.next().await {
-        if chunk.is_speech() {
-            for sample in chunk {
-                speech.write_sample(sample)?;
-            }
-        } else {
-            for sample in chunk {
-                nonspeech.write_sample(sample)?;
+        while let Some(chunk) = chunks.next().await {
+            if chunk.is_speech() {
+                for sample in chunk {
+                    speech.write_sample(sample)?;
+                }
+            } else {
+                for sample in chunk {
+                    nonspeech.write_sample(sample)?;
+                }
             }
         }
-    }
 
-    speech.finalize()?;
-    nonspeech.finalize()?;
-
+        speech.finalize()?;
+        nonspeech.finalize()?;
+    }
     Ok(())
 }
61 changes: 33 additions & 28 deletions tests/file_predict_stream.rs
@@ -1,36 +1,41 @@
-use tokio_stream::{self, StreamExt};
+#[cfg(feature = "async")]
+use tokio_stream::StreamExt;
+#[cfg(feature = "async")]
 use voice_activity_detector::{StreamExt as _, VoiceActivityDetector};
 
 #[tokio::test]
-async fn wave_file_predict_stream() -> Result<(), Box<dyn std::error::Error>> {
-    std::fs::create_dir_all("tests/.outputs")?;
-
-    let mut reader = hound::WavReader::open("tests/samples/sample.wav")?;
-    let spec = reader.spec();
-
-    let mut speech = hound::WavWriter::create("tests/.outputs/predict.stream.speech.wav", spec)?;
-    let mut nonspeech =
-        hound::WavWriter::create("tests/.outputs/predict.stream.nonspeech.wav", spec)?;
-
-    let vad = VoiceActivityDetector::<256>::try_with_sample_rate(spec.sample_rate)?;
-
-    let chunks = reader.samples::<i16>().map_while(Result::ok);
-    let mut chunks = tokio_stream::iter(chunks).predict(vad);
-
-    while let Some((chunk, probability)) = chunks.next().await {
-        if probability > 0.5 {
-            for sample in chunk {
-                speech.write_sample(sample)?;
-            }
-        } else {
-            for sample in chunk {
-                nonspeech.write_sample(sample)?;
+async fn wave_file_predict_stream() -> Result<(), Box<dyn std::error::Error>> {
+    #[cfg(feature = "async")]
+    {
+        std::fs::create_dir_all("tests/.outputs")?;
+
+        let mut reader = hound::WavReader::open("tests/samples/sample.wav")?;
+        let spec = reader.spec();
+
+        let mut speech =
+            hound::WavWriter::create("tests/.outputs/predict.stream.speech.wav", spec)?;
+        let mut nonspeech =
+            hound::WavWriter::create("tests/.outputs/predict.stream.nonspeech.wav", spec)?;
+
+        let vad = VoiceActivityDetector::<256>::try_with_sample_rate(spec.sample_rate)?;
+
+        let chunks = reader.samples::<i16>().map_while(Result::ok);
+        let mut chunks = tokio_stream::iter(chunks).predict(vad);
+
+        while let Some((chunk, probability)) = chunks.next().await {
+            if probability > 0.5 {
+                for sample in chunk {
+                    speech.write_sample(sample)?;
+                }
+            } else {
+                for sample in chunk {
+                    nonspeech.write_sample(sample)?;
+                }
             }
         }
-    }
 
-    speech.finalize()?;
-    nonspeech.finalize()?;
-
+        speech.finalize()?;
+        nonspeech.finalize()?;
+    }
     Ok(())
 }