diff --git a/.cursor/rules/simple.mdc b/.cursor/rules/simple.mdc index 92a5a10d9..509795c67 100644 --- a/.cursor/rules/simple.mdc +++ b/.cursor/rules/simple.mdc @@ -9,6 +9,11 @@ alwaysApply: true # Typescript - Avoid creating a bunch of types/interfaces if they are not shared. Especially for function props. Just inline them. +- After some amount of TypeScript changes, run `pnpm -r typecheck`. + +# Rust + +- After some amount of Rust changes, run `cargo check`. # Mutation - Never do manual state management for form/mutation. Things like setError is anti-pattern. use useForm(from tanstack-form) and useQuery/useMutation(from tanstack-query) for 99% cases. @@ -19,7 +24,6 @@ alwaysApply: true # Misc - Do not create summary docs or example code file if not requested. Plan is ok. -- After a significant amount of TypeScript changes, run `pnpm -r typecheck`. - If there are many classNames and they have conditional logic, use `cn` (import it with `import { cn } from "@hypr/utils"`). It is similar to `clsx`. Always pass an array. Split by logical grouping. - Use `motion/react` instead of `framer-motion`. diff --git a/Cargo.lock b/Cargo.lock index f9709850b..e70ba138a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7492,6 +7492,12 @@ dependencies = [ "icu_properties", ] +[[package]] +name = "if_chain" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd62e6b5e86ea8eeeb8db1de02880a6abc01a397b2ebb64b5d74ac255318f5cb" + [[package]] name = "ignore" version = "0.4.25" @@ -11411,22 +11417,32 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "ractor" -version = "0.15.9" +version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9500e0be6f12a0539cb1154d654ef2e888bf8529164e54aff4a097baad5bb001" +checksum = "1d65972a0286ef14c43c6daafbac6cf15e96496446147683b2905292c35cc178" dependencies = [ + "async-trait", "bon 2.3.0", "dashmap", "futures", - "js-sys", "once_cell", "strum 0.26.3", "tokio", - "tokio_with_wasm", "tracing", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-time", +] + +[[package]] +name = "ractor-supervisor" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d90830688ebfafdc226f3c9567c40fecf4c51a7513171181102ae66e4b57c15f" +dependencies = [ + "futures-util", + "if_chain", + "log", + "ractor", + "thiserror 2.0.17", + "uuid", ] [[package]] @@ -14656,6 +14672,7 @@ dependencies = [ "owhisper-client", "owhisper-interface", "ractor", + "ractor-supervisor", "rodio", "serde", "serde_json", @@ -15858,30 +15875,6 @@ dependencies = [ "webpki-roots 0.26.11", ] -[[package]] -name = "tokio_with_wasm" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dfba9b946459940fb564dcf576631074cdfb0bfe4c962acd4c31f0dca7897e6" -dependencies = [ - "js-sys", - "tokio", - "tokio_with_wasm_proc", - "wasm-bindgen", - "wasm-bindgen-futures", - "web-sys", -] - -[[package]] -name = "tokio_with_wasm_proc" -version = "0.8.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e04c1865c281139e5ccf633cb9f76ffdaabeebfe53b703984cf82878e2aabb" -dependencies = [ - "quote", - "syn 2.0.108", -] - [[package]] name = "toml" version = "0.8.23" diff --git a/Cargo.toml b/Cargo.toml index 9d083bdfc..13533c6b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,7 +132,8 @@ async-stream = "0.3.6" futures-channel = "0.3.31" futures-core = "0.3.31" futures-util = "0.3.31" -ractor = "0.15" +ractor = { version = "0.14.3" } +ractor-supervisor = "0.1.9" reqwest = "0.12" reqwest-streams = "0.10.0" tokio = "1" diff --git a/Taskfile.yaml b/Taskfile.yaml index b77a6858e..ea3188aaf 100644 --- a/Taskfile.yaml +++ b/Taskfile.yaml @@ -72,3 +72,12 @@ tasks: cmds: - chmod +x ./apps/desktop/src-tauri/resources/stt-aarch64-apple-darwin - chmod +x ./apps/desktop/src-tauri/resources/passthrough-aarch64-apple-darwin + + db: + env: + DB: /Users/yujonglee/Library/Application Support/com.hyprnote.nightly/db.sqlite + cmds: + - | + sqlite3 -json "$DB" 'SELECT store FROM main LIMIT 1;' | + jq -r '.[0].store' | + jless diff --git a/apps/desktop/src/components/main/body/sessions/floating/listen.tsx b/apps/desktop/src/components/main/body/sessions/floating/listen.tsx index b3599eb35..745207efc 100644 --- a/apps/desktop/src/components/main/body/sessions/floating/listen.tsx +++ b/apps/desktop/src/components/main/body/sessions/floating/listen.tsx @@ -200,7 +200,7 @@ function OptionsMenu({ queryClient.invalidateQueries({ queryKey: ["audio", sessionId, "url"] }); }) ), - Effect.flatMap((importedPath) => Effect.promise(() => runBatch(importedPath, { channels: 1 }))), + Effect.flatMap((importedPath) => Effect.promise(() => runBatch(importedPath))), ); }, [queryClient, runBatch, sessionId], diff --git a/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts b/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts index 4da042101..30672ac81 100644 --- a/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts +++ b/apps/desktop/src/components/main/body/sessions/note-input/transcript/shared/hooks.ts @@ -13,9 +13,11 @@ export function useFinalWords(transcriptId: string): (main.Word & { id: string } return []; } - return Object.entries(resultTable) + const ret = Object.entries(resultTable) .map(([wordId, row]) => ({ ...(row as unknown as main.Word), id: wordId })) .sort((a, b) => a.start_ms - b.start_ms); + + return ret; }, [resultTable]); } diff --git a/apps/desktop/src/hooks/useAutoEnhance.ts b/apps/desktop/src/hooks/useAutoEnhance.ts index 335e16d26..e3138a889 100644 --- a/apps/desktop/src/hooks/useAutoEnhance.ts +++ b/apps/desktop/src/hooks/useAutoEnhance.ts @@ -59,12 +59,7 @@ export function useAutoEnhance(tab: Extract) { if (listenerJustStopped) { startEnhance(); - } - }, [listenerStatus, prevListenerStatus, startEnhance]); - - useEffect(() => { - if (enhanceTask.status === "generating" && tab.state.editor !== "enhanced") { updateSessionTabState(tab, { editor: "enhanced" }); } - }, [enhanceTask.status, tab, updateSessionTabState]); + }, [listenerStatus, prevListenerStatus, startEnhance]); } diff --git a/apps/desktop/src/hooks/useRunBatch.ts b/apps/desktop/src/hooks/useRunBatch.ts index fcc76e69b..7a635ca7b 100644 --- a/apps/desktop/src/hooks/useRunBatch.ts +++ b/apps/desktop/src/hooks/useRunBatch.ts @@ -13,7 +13,6 @@ import { useSTTConnection } from "./useSTTConnection"; type RunOptions = { handlePersist?: HandlePersistCallback; - channels?: number; model?: string; baseUrl?: string; apiKey?: string; @@ -138,7 +137,6 @@ export const useRunBatch = (sessionId: string) => { api_key: options?.apiKey ?? conn.apiKey, keywords: options?.keywords ?? keywords ?? [], languages: options?.languages ?? languages ?? [], - channels: options?.channels, }; await runBatch(params, { handlePersist: persist, sessionId }); diff --git a/crates/audio-utils/src/lib.rs b/crates/audio-utils/src/lib.rs index 557c48a4a..609deffc2 100644 --- a/crates/audio-utils/src/lib.rs +++ b/crates/audio-utils/src/lib.rs @@ -1,3 +1,5 @@ +use std::convert::TryFrom; + use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{Stream, StreamExt}; use kalosm_sound::AsyncSource; @@ -11,6 +13,12 @@ pub use rodio::Source; const I16_SCALE: f32 = 32768.0; +#[derive(Debug, Clone, Copy)] +pub struct AudioMetadata { + pub sample_rate: u32, + pub channels: u8, +} + impl AudioFormatExt for T {} pub trait AudioFormatExt: AsyncSource { @@ -81,6 +89,40 @@ pub fn source_from_path( Ok(decoder) } +fn metadata_from_source(source: &S) -> Result +where + S: Source, + S::Item: rodio::Sample, +{ + let sample_rate = source.sample_rate(); + if sample_rate == 0 { + return Err(crate::Error::InvalidSampleRate(sample_rate)); + } + + let channels_u16 = source.channels(); + if channels_u16 == 0 { + return Err(crate::Error::UnsupportedChannelCount { + count: channels_u16, + }); + } + let channels = + u8::try_from(channels_u16).map_err(|_| crate::Error::UnsupportedChannelCount { + count: channels_u16, + })?; + + Ok(AudioMetadata { + sample_rate, + channels, + }) +} + +pub fn audio_file_metadata( + path: impl AsRef, +) -> Result { + let source = source_from_path(path)?; + metadata_from_source(&source) +} + pub fn resample_audio(source: S, to_rate: u32) -> Result, crate::Error> where S: rodio::Source + Iterator, @@ -136,32 +178,48 @@ where pub struct ChunkedAudio { pub chunks: Vec, pub sample_count: usize, + pub frame_count: usize, + pub metadata: AudioMetadata, } pub fn chunk_audio_file( path: impl AsRef, - sample_rate: u32, - chunk_size: usize, + chunk_ms: u64, ) -> Result { let source = source_from_path(path)?; - let samples = resample_audio(source, sample_rate)?; + let metadata = metadata_from_source(&source)?; + let samples = resample_audio(source, metadata.sample_rate)?; if samples.is_empty() { return Ok(ChunkedAudio { chunks: Vec::new(), sample_count: 0, + frame_count: 0, + metadata, }); } - let chunk_size = chunk_size.max(1); + let channels = metadata.channels.max(1) as usize; + let frames_per_chunk = { + let frames = ((chunk_ms as u128).saturating_mul(metadata.sample_rate as u128) + 999) / 1000; + frames.max(1).min(usize::MAX as u128) as usize + }; + let samples_per_chunk = frames_per_chunk + .saturating_mul(channels) + .max(1) + .min(usize::MAX); + let sample_count = samples.len(); + let frame_count = sample_count / channels; let chunks = samples - .chunks(chunk_size) + .chunks(samples_per_chunk) .map(|chunk| f32_to_i16_bytes(chunk.iter().copied())) .collect(); Ok(ChunkedAudio { chunks, sample_count, + frame_count, + metadata, }) } diff --git a/crates/audio/src/mic.rs b/crates/audio/src/mic.rs index d91233c7c..050f3fbaa 100644 --- a/crates/audio/src/mic.rs +++ b/crates/audio/src/mic.rs @@ -65,6 +65,10 @@ impl MicInput { config, }) } + + pub fn sample_rate(&self) -> u32 { + self.config.sample_rate().0 + } } impl MicInput { diff --git a/crates/audio/src/speaker/linux.rs b/crates/audio/src/speaker/linux.rs index 213b30e72..289b503f7 100644 --- a/crates/audio/src/speaker/linux.rs +++ b/crates/audio/src/speaker/linux.rs @@ -7,6 +7,10 @@ impl SpeakerInput { Self {} } + pub fn sample_rate(&self) -> u32 { + 16000 + } + pub fn stream(self) -> SpeakerStream { SpeakerStream::new() } diff --git a/crates/audio/src/speaker/macos.rs b/crates/audio/src/speaker/macos.rs index 82288a73b..b38cd9b92 100644 --- a/crates/audio/src/speaker/macos.rs +++ b/crates/audio/src/speaker/macos.rs @@ -91,6 +91,10 @@ impl SpeakerInput { Ok(Self { tap, agg_desc }) } + pub fn sample_rate(&self) -> u32 { + self.tap.asbd().unwrap().sample_rate as u32 + } + fn start_device( &self, ctx: &mut Box, diff --git a/crates/audio/src/speaker/mod.rs b/crates/audio/src/speaker/mod.rs index 41905a87c..cf9b1c50f 100644 --- a/crates/audio/src/speaker/mod.rs +++ b/crates/audio/src/speaker/mod.rs @@ -42,6 +42,16 @@ impl SpeakerInput { )) } + #[cfg(any(target_os = "macos", target_os = "windows"))] + pub fn sample_rate(&self) -> u32 { + self.inner.sample_rate() + } + + #[cfg(not(any(target_os = "macos", target_os = "windows")))] + pub fn sample_rate(&self) -> u32 { + 0 + } + #[cfg(any(target_os = "macos", target_os = "windows"))] pub fn stream(self) -> Result { let inner = self.inner.stream(); diff --git a/crates/audio/src/speaker/windows.rs b/crates/audio/src/speaker/windows.rs index 83e9f2d3c..60377c588 100644 --- a/crates/audio/src/speaker/windows.rs +++ b/crates/audio/src/speaker/windows.rs @@ -15,6 +15,10 @@ impl SpeakerInput { Ok(Self {}) } + pub fn sample_rate(&self) -> u32 { + 44100 + } + pub fn stream(self) -> SpeakerStream { let sample_queue = Arc::new(Mutex::new(VecDeque::new())); let waker_state = Arc::new(Mutex::new(WakerState { diff --git a/owhisper/owhisper-client/src/batch.rs b/owhisper/owhisper-client/src/batch.rs index cf1c9d46b..b9c7dca03 100644 --- a/owhisper/owhisper-client/src/batch.rs +++ b/owhisper/owhisper-client/src/batch.rs @@ -4,8 +4,10 @@ use tokio::task; use hypr_audio_utils::{f32_to_i16_bytes, resample_audio, source_from_path, Source}; use owhisper_interface::batch::Response as BatchResponse; -use crate::{error::Error, ListenClientBuilder, RESAMPLED_SAMPLE_RATE_HZ}; +use crate::{error::Error, ListenClientBuilder}; +// https://developers.deepgram.com/reference/speech-to-text/listen-pre-recorded +// https://github.com/deepgram/deepgram-rust-sdk/blob/main/src/listen/rest.rs #[derive(Clone)] pub struct BatchClient { pub(crate) client: reqwest::Client, @@ -13,29 +15,6 @@ pub struct BatchClient { pub(crate) api_key: Option, } -async fn decode_audio_to_linear16(path: PathBuf) -> Result<(bytes::Bytes, u16), Error> { - task::spawn_blocking(move || -> Result<(bytes::Bytes, u16), Error> { - let decoder = - source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?; - - let channel_count = decoder.channels(); - - let samples = resample_audio(decoder, RESAMPLED_SAMPLE_RATE_HZ) - .map_err(|err| Error::AudioProcessing(err.to_string()))?; - - if samples.is_empty() { - return Err(Error::AudioProcessing( - "audio file contains no samples".to_string(), - )); - } - - let bytes = f32_to_i16_bytes(samples.into_iter()); - - Ok((bytes, channel_count)) - }) - .await? -} - impl BatchClient { pub fn builder() -> ListenClientBuilder { ListenClientBuilder::default() @@ -46,14 +25,31 @@ impl BatchClient { file_path: P, ) -> Result { let path = file_path.as_ref(); - let (audio_data, channel_count) = decode_audio_to_linear16(path.to_path_buf()).await?; + let (audio_data, sample_rate) = decode_audio_to_linear16(path.to_path_buf()).await?; - let mut url = self.url.clone(); - let channel_value = channel_count.max(1).to_string(); - { - let mut query_pairs = url.query_pairs_mut(); - query_pairs.append_pair("channels", &channel_value); - } + let params = { + let mut params: Vec<(String, String)> = vec![]; + params.retain(|(key, _)| key != "channels"); + + params.push(("sample_rate".to_string(), sample_rate.to_string())); + params.push(("multichannel".to_string(), "false".to_string())); + params.push(("diarize".to_string(), "true".to_string())); + params.push(("detect_language".to_string(), "true".to_string())); + params + }; + + let url = { + let mut url = self.url.clone(); + + let mut serializer = url::form_urlencoded::Serializer::new(String::new()); + for (key, value) in params { + serializer.append_pair(&key, &value); + } + + let query = serializer.finish(); + url.set_query(Some(&query)); + url + }; let mut request = self.client.post(url); @@ -61,10 +57,7 @@ impl BatchClient { request = request.header("Authorization", format!("Token {}", key)); } - let content_type = format!( - "audio/raw;encoding=linear16;rate={}", - RESAMPLED_SAMPLE_RATE_HZ - ); + let content_type = format!("audio/raw;encoding=linear16;rate={}", sample_rate); let response = request .header("Accept", "application/json") @@ -84,3 +77,42 @@ impl BatchClient { } } } + +async fn decode_audio_to_linear16(path: PathBuf) -> Result<(bytes::Bytes, u32), Error> { + task::spawn_blocking(move || -> Result<(bytes::Bytes, u32), Error> { + let decoder = + source_from_path(&path).map_err(|err| Error::AudioProcessing(err.to_string()))?; + + let channels = decoder.channels().max(1); + let sample_rate = decoder.sample_rate(); + + let samples = resample_audio(decoder, sample_rate) + .map_err(|err| Error::AudioProcessing(err.to_string()))?; + + let samples = if channels == 1 { + samples + } else { + let channels_usize = channels as usize; + let mut mono = Vec::with_capacity(samples.len() / channels_usize); + for frame in samples.chunks(channels_usize) { + if frame.is_empty() { + continue; + } + let sum: f32 = frame.iter().copied().sum(); + mono.push(sum / frame.len() as f32); + } + mono + }; + + if samples.is_empty() { + return Err(Error::AudioProcessing( + "audio file contains no samples".to_string(), + )); + } + + let bytes = f32_to_i16_bytes(samples.into_iter()); + + Ok((bytes, sample_rate)) + }) + .await? +} diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index fc78422b6..22a141831 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -10,8 +10,6 @@ pub use error::Error; pub use hypr_ws; pub use live::{ListenClient, ListenClientDual}; -const RESAMPLED_SAMPLE_RATE_HZ: u32 = 16_000; - #[derive(Default)] pub struct ListenClientBuilder { api_base: Option, @@ -63,7 +61,7 @@ impl ListenClientBuilder { append_language_query(&mut query_pairs, ¶ms); let model = params.model.as_deref().unwrap_or("hypr-whisper"); - let sample_rate = RESAMPLED_SAMPLE_RATE_HZ.to_string(); + let sample_rate = params.sample_rate.to_string(); query_pairs.append_pair("model", model); query_pairs.append_pair("encoding", "linear16"); @@ -104,7 +102,7 @@ impl ListenClientBuilder { let model = params.model.as_deref().unwrap_or("hypr-whisper"); let channel_string = channels.to_string(); - let sample_rate = RESAMPLED_SAMPLE_RATE_HZ.to_string(); + let sample_rate = params.sample_rate.to_string(); query_pairs.append_pair("model", model); query_pairs.append_pair("channels", &channel_string); diff --git a/owhisper/owhisper-interface/src/lib.rs b/owhisper/owhisper-interface/src/lib.rs index 351c955f7..424a3a0d6 100644 --- a/owhisper/owhisper-interface/src/lib.rs +++ b/owhisper/owhisper-interface/src/lib.rs @@ -137,6 +137,7 @@ common_derives! { #[serde(default)] pub model: Option, pub channels: u8, + pub sample_rate: u32, // https://docs.rs/axum-extra/0.10.1/axum_extra/extract/struct.Query.html#example-1 #[serde(default)] pub languages: Vec, @@ -152,6 +153,7 @@ impl Default for ListenParams { ListenParams { model: None, channels: 1, + sample_rate: 16000, languages: vec![], keywords: vec![], redemption_time_ms: None, diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 5a71ff253..40bf28bda 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -52,8 +52,10 @@ uuid = { workspace = true, features = ["v4"] } hound = { workspace = true } vorbis_rs = { workspace = true } -futures-util = { workspace = true } ractor = { workspace = true } +ractor-supervisor = { workspace = true } + +futures-util = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tokio-stream = { workspace = true } tokio-util = { workspace = true } diff --git a/plugins/listener/js/bindings.gen.ts b/plugins/listener/js/bindings.gen.ts index b1a61dbce..7555ad3ed 100644 --- a/plugins/listener/js/bindings.gen.ts +++ b/plugins/listener/js/bindings.gen.ts @@ -98,7 +98,7 @@ sessionEvent: "plugin:listener:session-event" export type BatchAlternatives = { transcript: string; confidence: number; words?: BatchWord[] } export type BatchChannel = { alternatives: BatchAlternatives[] } -export type BatchParams = { session_id: string; provider: BatchProvider; file_path: string; model?: string | null; base_url: string; api_key: string; languages?: string[]; keywords?: string[]; channels?: number | null } +export type BatchParams = { session_id: string; provider: BatchProvider; file_path: string; model?: string | null; base_url: string; api_key: string; languages?: string[]; keywords?: string[] } export type BatchProvider = "deepgram" | "am" export type BatchResponse = { metadata: JsonValue; results: BatchResults } export type BatchResults = { channels: BatchChannel[] } diff --git a/plugins/listener/src/actors/batch.rs b/plugins/listener/src/actors/batch.rs index cd71c3084..c0ab160a2 100644 --- a/plugins/listener/src/actors/batch.rs +++ b/plugins/listener/src/actors/batch.rs @@ -4,13 +4,11 @@ use std::time::Duration; use owhisper_interface::stream::StreamResponse; use owhisper_interface::{ControlMessage, MixedMessage}; -use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SpawnErr}; use tauri_specta::Event; use tokio_stream::{self as tokio_stream, StreamExt as TokioStreamExt}; use crate::SessionEvent; - -const RESAMPLED_SAMPLE_RATE_HZ: u32 = 16_000; const BATCH_STREAM_TIMEOUT_SECS: u64 = 5; const DEFAULT_CHUNK_MS: u64 = 500; const DEFAULT_DELAY_MS: u64 = 20; @@ -91,6 +89,12 @@ impl BatchActor { } } +pub async fn spawn_batch_actor(args: BatchArgs) -> Result, SpawnErr> { + let (batch_ref, _) = Actor::spawn(Some(BatchActor::name()), BatchActor, args).await?; + Ok(batch_ref) +} + +#[ractor::async_trait] impl Actor for BatchActor { type Msg = BatchMsg; type State = BatchState; @@ -188,13 +192,6 @@ impl BatchStreamConfig { } } - fn chunk_samples(&self) -> usize { - let samples = - ((self.chunk_ms as u128).saturating_mul(RESAMPLED_SAMPLE_RATE_HZ as u128) + 999) / 1000; - let samples = samples.max(1); - samples.min(usize::MAX as u128) as usize - } - fn chunk_interval(&self) -> Duration { Duration::from_millis(self.delay_ms) } @@ -225,12 +222,10 @@ async fn spawn_batch_task( let stream_config = BatchStreamConfig::new(DEFAULT_CHUNK_MS, DEFAULT_DELAY_MS); let start_notifier = args.start_notifier.clone(); - let chunk_samples = stream_config.chunk_samples(); let chunk_result = tokio::task::spawn_blocking({ let path = PathBuf::from(&args.file_path); - move || { - hypr_audio_utils::chunk_audio_file(path, RESAMPLED_SAMPLE_RATE_HZ, chunk_samples) - } + let chunk_ms = stream_config.chunk_ms; + move || hypr_audio_utils::chunk_audio_file(path, chunk_ms) }) .await; @@ -258,20 +253,25 @@ async fn spawn_batch_task( } }; - let sample_count = chunked_audio.sample_count; - let audio_duration_secs = if sample_count == 0 { + let frame_count = chunked_audio.frame_count; + let metadata = chunked_audio.metadata; + let audio_duration_secs = if frame_count == 0 || metadata.sample_rate == 0 { 0.0 } else { - sample_count as f64 / RESAMPLED_SAMPLE_RATE_HZ as f64 + frame_count as f64 / metadata.sample_rate as f64 }; let _ = myself.send_message(BatchMsg::StreamAudioDuration(audio_duration_secs)); - tracing::debug!("batch task: creating listen client"); - let channel_count = args.listen_params.channels.clamp(1, 2); + let channel_count = metadata.channels.clamp(1, 2); + let listen_params = owhisper_interface::ListenParams { + channels: metadata.channels, + sample_rate: metadata.sample_rate, + ..args.listen_params.clone() + }; let client = owhisper_client::ListenClient::builder() .api_base(args.base_url) .api_key(args.api_key) - .params(args.listen_params.clone()) + .params(listen_params) .build_with_channels(channel_count); let chunk_count = chunked_audio.chunks.len(); diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 733b0a204..e63384c23 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -54,6 +54,7 @@ impl ListenerActor { } } +#[ractor::async_trait] impl Actor for ListenerActor { type Msg = ListenerMsg; type State = ListenerState; diff --git a/plugins/listener/src/actors/processor.rs b/plugins/listener/src/actors/processor.rs index fec275c9e..ad3daadf9 100644 --- a/plugins/listener/src/actors/processor.rs +++ b/plugins/listener/src/actors/processor.rs @@ -55,6 +55,7 @@ impl ProcessorActor { } } +#[ractor::async_trait] impl Actor for ProcessorActor { type Msg = ProcMsg; type State = ProcState; diff --git a/plugins/listener/src/actors/recorder.rs b/plugins/listener/src/actors/recorder.rs index 241d86bb7..815007ac4 100644 --- a/plugins/listener/src/actors/recorder.rs +++ b/plugins/listener/src/actors/recorder.rs @@ -34,6 +34,7 @@ impl RecorderActor { } } +#[ractor::async_trait] impl Actor for RecorderActor { type Msg = RecMsg; type State = RecState; diff --git a/plugins/listener/src/actors/session.rs b/plugins/listener/src/actors/session.rs index 5d9d3c5f2..5b1cb4c07 100644 --- a/plugins/listener/src/actors/session.rs +++ b/plugins/listener/src/actors/session.rs @@ -58,6 +58,7 @@ impl SessionActor { } } +#[ractor::async_trait] impl Actor for SessionActor { type Msg = SessionMsg; type State = SessionState; diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs index 1f8378d5c..6e9a670fa 100644 --- a/plugins/listener/src/actors/source.rs +++ b/plugins/listener/src/actors/source.rs @@ -50,6 +50,7 @@ impl SourceActor { } } +#[ractor::async_trait] impl Actor for SourceActor { type Msg = SourceMsg; type State = SourceState; diff --git a/plugins/listener/src/events.rs b/plugins/listener/src/events.rs index 99b628a22..90dd2657c 100644 --- a/plugins/listener/src/events.rs +++ b/plugins/listener/src/events.rs @@ -40,6 +40,7 @@ impl From<(&[f32], &[f32])> for SessionEvent { let mic = (mic_chunk .iter() .map(|&x| x.abs()) + .filter(|x| x.is_finite()) .max_by(|a, b| a.partial_cmp(b).unwrap()) .unwrap_or(0.0) * 100.0) as u16; @@ -47,6 +48,7 @@ impl From<(&[f32], &[f32])> for SessionEvent { let speaker = (speaker_chunk .iter() .map(|&x| x.abs()) + .filter(|x| x.is_finite()) .max_by(|a, b| a.partial_cmp(b).unwrap()) .unwrap_or(0.0) * 100.0) as u16; diff --git a/plugins/listener/src/ext.rs b/plugins/listener/src/ext.rs index 8cfba9478..4055ef515 100644 --- a/plugins/listener/src/ext.rs +++ b/plugins/listener/src/ext.rs @@ -4,8 +4,9 @@ use std::sync::{Arc, Mutex}; use ractor::{call_t, concurrency, registry, Actor, ActorRef}; use tauri_specta::Event; +use crate::actors::spawn_batch_actor; use crate::{ - actors::{BatchActor, BatchArgs, SessionActor, SessionArgs, SessionMsg, SessionParams}, + actors::{BatchArgs, SessionActor, SessionArgs, SessionMsg, SessionParams}, SessionEvent, }; @@ -29,8 +30,6 @@ pub struct BatchParams { pub languages: Vec, #[serde(default)] pub keywords: Vec, - #[serde(default)] - pub channels: Option, } pub trait ListenerPluginExt { @@ -150,11 +149,22 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn run_batch(&self, params: BatchParams) -> Result<(), crate::Error> { - let channels = params.channels.unwrap_or(1); + let metadata = tokio::task::spawn_blocking({ + let path = params.file_path.clone(); + move || hypr_audio_utils::audio_file_metadata(path) + }) + .await + .map_err(|err| { + crate::Error::BatchStartFailed(format!("failed to join audio metadata task: {err:?}")) + })? + .map_err(|err| { + crate::Error::BatchStartFailed(format!("failed to read audio metadata: {err}")) + })?; let listen_params = owhisper_interface::ListenParams { model: params.model.clone(), - channels, + channels: metadata.channels, + sample_rate: metadata.sample_rate, languages: params.languages.clone(), keywords: params.keywords.clone(), redemption_time_ms: None, @@ -171,20 +181,16 @@ impl> ListenerPluginExt for T { let app = guard.app.clone(); drop(guard); - match Actor::spawn( - Some(BatchActor::name()), - BatchActor, - BatchArgs { - app, - file_path: params.file_path.clone(), - base_url: params.base_url.clone(), - api_key: params.api_key.clone(), - listen_params, - start_notifier: start_notifier.clone(), - }, - ) - .await - { + let args = BatchArgs { + app, + file_path: params.file_path.clone(), + base_url: params.base_url.clone(), + api_key: params.api_key.clone(), + listen_params: listen_params.clone(), + start_notifier: start_notifier.clone(), + }; + + match spawn_batch_actor(args).await { Ok(_) => { tracing::info!("batch actor spawned successfully"); let state = self.state::(); @@ -196,11 +202,11 @@ impl> ListenerPluginExt for T { .unwrap(); } Err(e) => { - tracing::error!("batch actor spawn failed: {:?}", e); + tracing::error!("batch supervisor spawn failed: {:?}", e); if let Ok(mut notifier) = start_notifier.lock() { if let Some(tx) = notifier.take() { - let _ = - tx.send(Err(format!("failed to spawn batch actor: {:?}", e))); + let _ = tx + .send(Err(format!("failed to spawn batch supervisor: {e:?}"))); } } return Err(e.into()); diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index d91e9e1b8..4f8d87ac3 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -37,6 +37,7 @@ impl ExternalSTTActor { } } +#[ractor::async_trait] impl Actor for ExternalSTTActor { type Msg = ExternalSTTMessage; type State = ExternalSTTState; @@ -61,6 +62,8 @@ impl Actor for ExternalSTTActor { let text = text.trim(); if !text.is_empty() && !text.contains("[WebSocket]") + && !text.contains("Sent interim text:") + && !text.contains("[TranscriptionHandler]") && !text.contains("/v1/status") { tracing::info!("{}", text); diff --git a/plugins/local-stt/src/server/internal.rs b/plugins/local-stt/src/server/internal.rs index 01d715101..90344ce7f 100644 --- a/plugins/local-stt/src/server/internal.rs +++ b/plugins/local-stt/src/server/internal.rs @@ -36,6 +36,7 @@ impl InternalSTTActor { } } +#[ractor::async_trait] impl Actor for InternalSTTActor { type Msg = InternalSTTMessage; type State = InternalSTTState;