diff --git a/crates/audio/src/lib.rs b/crates/audio/src/lib.rs index e7b37f49e..374c244a8 100644 --- a/crates/audio/src/lib.rs +++ b/crates/audio/src/lib.rs @@ -2,12 +2,14 @@ mod device_monitor; mod errors; mod mic; mod norm; +mod resampler; mod speaker; pub use device_monitor::*; pub use errors::*; pub use mic::*; pub use norm::*; +pub use resampler::*; pub use speaker::*; pub use cpal; diff --git a/crates/audio/src/resampler.rs b/crates/audio/src/resampler.rs new file mode 100644 index 000000000..b1bf3d3dc --- /dev/null +++ b/crates/audio/src/resampler.rs @@ -0,0 +1,206 @@ +use dasp::interpolate::Interpolator; +use futures_util::Stream; +use kalosm_sound::AsyncSource; + +pub struct ResampledAsyncSource { + source: S, + target_sample_rate: u32, + sample_position: f64, + resampler: dasp::interpolate::linear::Linear, + last_source_rate: u32, +} + +impl ResampledAsyncSource { + pub fn new(source: S, target_sample_rate: u32) -> Self { + let initial_rate = source.sample_rate(); + Self { + source, + target_sample_rate, + sample_position: initial_rate as f64 / target_sample_rate as f64, + resampler: dasp::interpolate::linear::Linear::new(0.0, 0.0), + last_source_rate: initial_rate, + } + } +} + +impl Stream for ResampledAsyncSource { + type Item = f32; + + fn poll_next( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + let myself = self.get_mut(); + + let current_source_rate = myself.source.sample_rate(); + if current_source_rate != myself.last_source_rate { + myself.last_source_rate = current_source_rate; + } + + let source_output_sample_ratio = + current_source_rate as f64 / myself.target_sample_rate as f64; + + let source = myself.source.as_stream(); + let mut source = std::pin::pin!(source); + + while myself.sample_position >= 1.0 { + match source.as_mut().poll_next(cx) { + std::task::Poll::Ready(Some(frame)) => { + myself.sample_position -= 1.0; + myself.resampler.next_source_frame(frame); + } + std::task::Poll::Ready(None) => return std::task::Poll::Ready(None), + std::task::Poll::Pending => return std::task::Poll::Pending, + } + } + + let interpolated = myself.resampler.interpolate(myself.sample_position); + myself.sample_position += source_output_sample_ratio; + + std::task::Poll::Ready(Some(interpolated)) + } +} + +impl AsyncSource for ResampledAsyncSource { + fn as_stream(&mut self) -> impl Stream + '_ { + self + } + + fn sample_rate(&self) -> u32 { + self.target_sample_rate + } +} + +#[cfg(test)] +mod tests { + use futures_util::{Stream, StreamExt}; + use kalosm_sound::AsyncSource; + use rodio::Source; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use crate::ResampledAsyncSource; + + fn get_samples_with_rate(path: impl AsRef) -> (Vec, u32) { + let source = + rodio::Decoder::new(std::io::BufReader::new(std::fs::File::open(path).unwrap())) + .unwrap(); + + let sample_rate = AsyncSource::sample_rate(&source); + let samples = source.convert_samples::().collect(); + (samples, sample_rate) + } + + struct DynamicRateSource { + segments: Vec<(Vec, u32)>, + current_segment: usize, + current_position: usize, + } + + impl DynamicRateSource { + fn new(segments: Vec<(Vec, u32)>) -> Self { + Self { + segments, + current_segment: 0, + current_position: 0, + } + } + } + + impl AsyncSource for DynamicRateSource { + fn as_stream(&mut self) -> impl Stream + '_ { + DynamicRateStream { source: self } + } + + fn sample_rate(&self) -> u32 { + if self.current_segment < self.segments.len() { + self.segments[self.current_segment].1 + } else { + unreachable!() + } + } + } + + struct DynamicRateStream<'a> { + source: &'a mut DynamicRateSource, + } + + impl<'a> Stream for DynamicRateStream<'a> { + type Item = f32; + + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + let source = &mut self.source; + + while source.current_segment < source.segments.len() { + let (samples, _rate) = &source.segments[source.current_segment]; + + if source.current_position < samples.len() { + let sample = samples[source.current_position]; + source.current_position += 1; + return Poll::Ready(Some(sample)); + } + + source.current_segment += 1; + source.current_position = 0; + } + + Poll::Ready(None) + } + } + + #[tokio::test] + async fn test_existing_resampler() { + let source = DynamicRateSource::new(vec![ + get_samples_with_rate(hypr_data::english_1::AUDIO_PART1_8000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART2_16000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART3_22050HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART4_32000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART5_44100HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART6_48000HZ_PATH), + ]); + + let mut out_wav = hound::WavWriter::create( + "./out_1.wav", + hound::WavSpec { + channels: 1, + sample_rate: 16000, + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, + }, + ) + .unwrap(); + + let mut resampled = source.resample(16000); + while let Some(sample) = resampled.next().await { + out_wav.write_sample(sample).unwrap(); + } + } + + #[tokio::test] + async fn test_new_resampler() { + let source = DynamicRateSource::new(vec![ + get_samples_with_rate(hypr_data::english_1::AUDIO_PART1_8000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART2_16000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART3_22050HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART4_32000HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART5_44100HZ_PATH), + get_samples_with_rate(hypr_data::english_1::AUDIO_PART6_48000HZ_PATH), + ]); + + let mut out_wav = hound::WavWriter::create( + "./out_2.wav", + hound::WavSpec { + channels: 1, + sample_rate: 16000, + bits_per_sample: 32, + sample_format: hound::SampleFormat::Float, + }, + ) + .unwrap(); + + let mut resampled = ResampledAsyncSource::new(source, 16000); + while let Some(sample) = resampled.next().await { + out_wav.write_sample(sample).unwrap(); + } + } +} diff --git a/crates/audio/src/speaker/macos.rs b/crates/audio/src/speaker/macos.rs index 1903fac28..4c5fa83f1 100644 --- a/crates/audio/src/speaker/macos.rs +++ b/crates/audio/src/speaker/macos.rs @@ -1,3 +1,4 @@ +use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex}; use std::task::{Poll, Waker}; @@ -25,16 +26,16 @@ struct WakerState { pub struct SpeakerStream { consumer: HeapCons, - stream_desc: cat::AudioStreamBasicDesc, _device: ca::hardware::StartedDevice, _ctx: Box, _tap: ca::TapGuard, waker_state: Arc>, + current_sample_rate: Arc, } impl SpeakerStream { pub fn sample_rate(&self) -> u32 { - self.stream_desc.sample_rate as u32 + self.current_sample_rate.load(Ordering::Relaxed) } } @@ -42,6 +43,7 @@ struct Ctx { format: arc::R, producer: HeapProd, waker_state: Arc>, + current_sample_rate: Arc, } impl SpeakerInput { @@ -100,7 +102,7 @@ impl SpeakerInput { ctx: &mut Box, ) -> Result> { extern "C" fn proc( - _device: ca::Device, + device: ca::Device, _now: &cat::AudioTimeStamp, input_data: &cat::AudioBufList<1>, _input_time: &cat::AudioTimeStamp, @@ -110,6 +112,13 @@ impl SpeakerInput { ) -> os::Status { let ctx = ctx.unwrap(); + ctx.current_sample_rate.store( + device + .actual_sample_rate() + .unwrap_or(ctx.format.absd().sample_rate) as u32, + Ordering::Relaxed, + ); + assert_eq!(ctx.format.common_format(), av::audio::CommonFormat::PcmF32); if let Some(view) = @@ -157,21 +166,24 @@ impl SpeakerInput { has_data: false, })); + let current_sample_rate = Arc::new(AtomicU32::new(asbd.sample_rate as u32)); + let mut ctx = Box::new(Ctx { format, producer, waker_state: waker_state.clone(), + current_sample_rate: current_sample_rate.clone(), }); let device = self.start_device(&mut ctx).unwrap(); SpeakerStream { consumer, - stream_desc: asbd, _device: device, _ctx: ctx, _tap: self.tap, waker_state, + current_sample_rate, } } } diff --git a/crates/data/scripts/resamples.sh b/crates/data/scripts/resamples.sh new file mode 100644 index 000000000..6b1be5edc --- /dev/null +++ b/crates/data/scripts/resamples.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Check if input file is provided +if [ $# -eq 0 ]; then + echo "Usage: $0 " + exit 1 +fi + +INPUT_FILE="$1" +DIR=$(dirname "$INPUT_FILE") +BASENAME=$(basename "$INPUT_FILE" .wav) + +# Array of common sample rates for testing +SAMPLE_RATES=(8000 16000 22050 32000 44100 48000) + +# Get duration of input file in seconds +DURATION=$(ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 "$INPUT_FILE") + +# Calculate part duration +NUM_PARTS=${#SAMPLE_RATES[@]} +PART_DURATION=$(echo "$DURATION / $NUM_PARTS" | bc -l) + +# Generate parts with different sample rates +for i in "${!SAMPLE_RATES[@]}"; do + RATE=${SAMPLE_RATES[$i]} + PART_NUM=$((i + 1)) + START=$(echo "$i * $PART_DURATION" | bc -l) + OUTPUT_FILE="${DIR}/${BASENAME}_part${PART_NUM}_${RATE}hz.wav" + + echo "Creating part ${PART_NUM}: ${START}s-$(echo "$START + $PART_DURATION" | bc -l)s at ${RATE}Hz" + ffmpeg -i "$INPUT_FILE" -ss ${START} -t ${PART_DURATION} -ar ${RATE} "$OUTPUT_FILE" -y -loglevel error +done + +echo "Done! Created ${NUM_PARTS} parts with different sample rates in ${DIR}/" \ No newline at end of file diff --git a/crates/data/src/english_1/audio_part1_8000hz.wav b/crates/data/src/english_1/audio_part1_8000hz.wav new file mode 100644 index 000000000..b2ae29ce1 Binary files /dev/null and b/crates/data/src/english_1/audio_part1_8000hz.wav differ diff --git a/crates/data/src/english_1/audio_part2_16000hz.wav b/crates/data/src/english_1/audio_part2_16000hz.wav new file mode 100644 index 000000000..42f707908 Binary files /dev/null and b/crates/data/src/english_1/audio_part2_16000hz.wav differ diff --git a/crates/data/src/english_1/audio_part3_22050hz.wav b/crates/data/src/english_1/audio_part3_22050hz.wav new file mode 100644 index 000000000..f72060f6c Binary files /dev/null and b/crates/data/src/english_1/audio_part3_22050hz.wav differ diff --git a/crates/data/src/english_1/audio_part4_32000hz.wav b/crates/data/src/english_1/audio_part4_32000hz.wav new file mode 100644 index 000000000..214b6708b Binary files /dev/null and b/crates/data/src/english_1/audio_part4_32000hz.wav differ diff --git a/crates/data/src/english_1/audio_part5_44100hz.wav b/crates/data/src/english_1/audio_part5_44100hz.wav new file mode 100644 index 000000000..707218d0a Binary files /dev/null and b/crates/data/src/english_1/audio_part5_44100hz.wav differ diff --git a/crates/data/src/english_1/audio_part6_48000hz.wav b/crates/data/src/english_1/audio_part6_48000hz.wav new file mode 100644 index 000000000..7ad78758f Binary files /dev/null and b/crates/data/src/english_1/audio_part6_48000hz.wav differ diff --git a/crates/data/src/english_1/mod.rs b/crates/data/src/english_1/mod.rs index 586970626..24f8295da 100644 --- a/crates/data/src/english_1/mod.rs +++ b/crates/data/src/english_1/mod.rs @@ -1,6 +1,31 @@ pub const AUDIO: &[u8] = include_wav!("./audio.wav"); pub const AUDIO_PATH: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/src/english_1/audio.wav"); +pub const AUDIO_PART1_8000HZ_PATH: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/english_1/audio_part1_8000hz.wav" +); +pub const AUDIO_PART2_16000HZ_PATH: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/english_1/audio_part2_16000hz.wav" +); +pub const AUDIO_PART3_22050HZ_PATH: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/english_1/audio_part3_22050hz.wav" +); +pub const AUDIO_PART4_32000HZ_PATH: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/english_1/audio_part4_32000hz.wav" +); +pub const AUDIO_PART5_44100HZ_PATH: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/english_1/audio_part5_44100hz.wav" +); +pub const AUDIO_PART6_48000HZ_PATH: &str = concat!( + env!("CARGO_MANIFEST_DIR"), + "/src/english_1/audio_part6_48000hz.wav" +); + pub const TRANSCRIPTION_JSON: &str = include_str!("./transcription.json"); pub const TRANSCRIPTION_PATH: &str = concat!( diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index fab0d05bc..978261b18 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -9,7 +9,7 @@ use tauri_specta::Event; use futures_util::StreamExt; use tokio::task::JoinSet; -use hypr_audio::AsyncSource; +use hypr_audio::ResampledAsyncSource; use crate::{manager::TranscriptManager, SessionEvent}; @@ -249,18 +249,11 @@ impl Session { let mut input = hypr_audio::AudioInput::from_mic(self.mic_device_name.clone())?; input.stream() }; - let mic_stream = mic_sample_stream - .resample(SAMPLE_RATE) - .chunks(hypr_aec::BLOCK_SIZE); - - // https://github.com/fastrepl/hyprnote/commit/7c8cf1c - tokio::time::sleep(Duration::from_millis(65)).await; - // We need some delay here for Airpod transition. - // But if the delay is too long, AEC will not work. + let mic_stream = + ResampledAsyncSource::new(mic_sample_stream, SAMPLE_RATE).chunks(hypr_aec::BLOCK_SIZE); let speaker_sample_stream = hypr_audio::AudioInput::from_speaker().stream(); - let speaker_stream = speaker_sample_stream - .resample(SAMPLE_RATE) + let speaker_stream = ResampledAsyncSource::new(speaker_sample_stream, SAMPLE_RATE) .chunks(hypr_aec::BLOCK_SIZE); let channels = AudioChannels::new(); @@ -316,7 +309,10 @@ impl Session { let mic_chunk = match maybe_mic_chunk { Ok(mic_chunk) => mic_chunk, - Err(_) => mic_chunk_raw, + Err(e) => { + tracing::error!("aec_error: {:?}", e); + mic_chunk_raw + } }; if matches!(*session_state_rx.borrow(), State::RunningPaused {}) { diff --git a/plugins/notification/src/handler.rs b/plugins/notification/src/handler.rs index 17cb96ea1..efd9da23d 100644 --- a/plugins/notification/src/handler.rs +++ b/plugins/notification/src/handler.rs @@ -78,7 +78,7 @@ impl NotificationHandler { hypr_detect::DetectEvent::MicStopped => { use tauri_plugin_listener::ListenerPluginExt; let app_handle = app_handle.clone(); - tokio::spawn(async move { + tauri::async_runtime::spawn(async move { app_handle.pause_session().await; }); }