diff --git a/Cargo.lock b/Cargo.lock index 4b10d0e6a..80ed1291e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12,12 +12,6 @@ dependencies = [ "regex", ] -[[package]] -name = "accelerate-src" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" - [[package]] name = "actix-codec" version = "0.5.2" @@ -1009,6 +1003,9 @@ dependencies = [ "bytes", "futures-util", "kalosm-sound", + "rodio", + "rubato", + "thiserror 2.0.12", ] [[package]] @@ -2197,14 +2194,10 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06ccf5ee3532e66868516d9b315f73aec9f34ea1a37ae98514534d458915dbf1" dependencies = [ - "accelerate-src", "byteorder", - "candle-metal-kernels", "gemm 0.17.1", "half", - "libc", "memmap2", - "metal 0.27.0", "num-traits", "num_cpus", "rand 0.9.1", @@ -2213,34 +2206,18 @@ dependencies = [ "safetensors", "thiserror 1.0.69", "ug", - "ug-metal", "yoke 0.7.5", "zip 1.1.4", ] -[[package]] -name = "candle-metal-kernels" -version = "0.8.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52c85c21827c28db94e7112e364abe7e0cf8d2b022c014edf08642be6b94f21e" -dependencies = [ - "metal 0.27.0", - "once_cell", - "thiserror 1.0.69", - "tracing", -] - [[package]] name = "candle-nn" version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be1160c3b63f47d40d91110a3e1e1e566ae38edddbbf492a60b40ffc3bc1ff38" dependencies = [ - "accelerate-src", "candle-core", - "candle-metal-kernels", "half", - "metal 0.27.0", "num-traits", "rayon", "safetensors", @@ -2254,7 +2231,6 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94a0900d49f8605e0e7e6693a1f560e6271279de98e5fa369e7abf3aac245020" dependencies = [ - "accelerate-src", "byteorder", "candle-core", "candle-nn", @@ -2649,7 +2625,7 @@ dependencies = [ "bitflags 2.9.1", "block", "core-foundation 0.10.0", - "core-graphics-types 0.2.0", + "core-graphics-types", "objc", ] @@ -2845,22 +2821,11 @@ checksum = "fa95a34622365fa5bbf40b20b75dba8dfa8c94c734aea8ac9a5ca38af14316f1" dependencies = [ "bitflags 2.9.1", "core-foundation 0.10.0", - "core-graphics-types 0.2.0", + "core-graphics-types", "foreign-types 0.5.0", "libc", ] -[[package]] -name = "core-graphics-types" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45390e6114f68f718cc7a830514a96f903cccd70d02a8f6d9f643ac4ba45afaf" -dependencies = [ - "bitflags 1.3.2", - "core-foundation 0.9.4", - "libc", -] - [[package]] name = "core-graphics-types" version = "0.2.0" @@ -4230,6 +4195,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "extended" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af9673d8203fcb076b19dfd17e38b3d4ae9f44959416ea532ce72415a6020365" + [[package]] name = "eyre" version = "0.6.12" @@ -7446,7 +7417,6 @@ dependencies = [ "hf-hub", "httpdate", "kalosm-model-types", - "metal 0.29.0", "reqwest 0.11.27", "thiserror 2.0.12", "tokio", @@ -8277,36 +8247,6 @@ dependencies = [ "autocfg", ] -[[package]] -name = "metal" -version = "0.27.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c43f73953f8cbe511f021b58f18c3ce1c3d1ae13fe953293e13345bf83217f25" -dependencies = [ - "bitflags 2.9.1", - "block", - "core-graphics-types 0.1.3", - "foreign-types 0.5.0", - "log", - "objc", - "paste", -] - -[[package]] -name = "metal" -version = "0.29.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" -dependencies = [ - "bitflags 2.9.1", - "block", - "core-graphics-types 0.1.3", - "foreign-types 0.5.0", - "log", - "objc", - "paste", -] - [[package]] name = "mime" version = "0.3.17" @@ -8965,7 +8905,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "915b1b472bc21c53464d6c8461c9d3af805ba1ef837e1cac254428f4a77177b1" dependencies = [ "malloc_buf", - "objc_exception", ] [[package]] @@ -9361,15 +9300,6 @@ dependencies = [ "objc2-foundation 0.3.1", ] -[[package]] -name = "objc_exception" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad970fb455818ad6cba4c122ad012fae53ae8b4795f86378bce65e4f6bab2ca4" -dependencies = [ - "cc", -] - [[package]] name = "objc_id" version = "0.1.1" @@ -11481,6 +11411,18 @@ dependencies = [ "tonic-build", ] +[[package]] +name = "rubato" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5258099699851cfd0082aeb645feb9c084d9a5e1f1b8d5372086b989fc5e56a1" +dependencies = [ + "num-complex", + "num-integer", + "num-traits", + "realfft", +] + [[package]] name = "rust-ini" version = "0.21.1" @@ -11739,7 +11681,6 @@ name = "rwhisper" version = "0.4.1" source = "git+https://github.com/floneum/floneum?rev=52967ae#52967ae5dcd161e5cbe9507282b52df0946706f9" dependencies = [ - "accelerate-src", "byteorder", "candle-core", "candle-nn", @@ -13230,9 +13171,28 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "815c942ae7ee74737bb00f965fa5b5a2ac2ce7b6c01c0cc169bbeaf7abd5f5a9" dependencies = [ "lazy_static", + "symphonia-bundle-flac", "symphonia-bundle-mp3", + "symphonia-codec-aac", + "symphonia-codec-adpcm", + "symphonia-codec-pcm", + "symphonia-codec-vorbis", + "symphonia-core", + "symphonia-format-isomp4", + "symphonia-format-riff", + "symphonia-metadata", +] + +[[package]] +name = "symphonia-bundle-flac" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72e34f34298a7308d4397a6c7fbf5b84c5d491231ce3dd379707ba673ab3bd97" +dependencies = [ + "log", "symphonia-core", "symphonia-metadata", + "symphonia-utils-xiph", ] [[package]] @@ -13247,6 +13207,48 @@ dependencies = [ "symphonia-metadata", ] +[[package]] +name = "symphonia-codec-aac" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdbf25b545ad0d3ee3e891ea643ad115aff4ca92f6aec472086b957a58522f70" +dependencies = [ + "lazy_static", + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-codec-adpcm" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c94e1feac3327cd616e973d5be69ad36b3945f16b06f19c6773fc3ac0b426a0f" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-codec-pcm" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f395a67057c2ebc5e84d7bb1be71cce1a7ba99f64e0f0f0e303a03f79116f89b" +dependencies = [ + "log", + "symphonia-core", +] + +[[package]] +name = "symphonia-codec-vorbis" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a98765fb46a0a6732b007f7e2870c2129b6f78d87db7987e6533c8f164a9f30" +dependencies = [ + "log", + "symphonia-core", + "symphonia-utils-xiph", +] + [[package]] name = "symphonia-core" version = "0.5.4" @@ -13260,6 +13262,31 @@ dependencies = [ "log", ] +[[package]] +name = "symphonia-format-isomp4" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abfdf178d697e50ce1e5d9b982ba1b94c47218e03ec35022d9f0e071a16dc844" +dependencies = [ + "encoding_rs", + "log", + "symphonia-core", + "symphonia-metadata", + "symphonia-utils-xiph", +] + +[[package]] +name = "symphonia-format-riff" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f7be232f962f937f4b7115cbe62c330929345434c834359425e043bfd15f50" +dependencies = [ + "extended", + "log", + "symphonia-core", + "symphonia-metadata", +] + [[package]] name = "symphonia-metadata" version = "0.5.4" @@ -13272,6 +13299,16 @@ dependencies = [ "symphonia-core", ] +[[package]] +name = "symphonia-utils-xiph" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "484472580fa49991afda5f6550ece662237b00c6f562c7d9638d1b086ed010fe" +dependencies = [ + "symphonia-core", + "symphonia-metadata", +] + [[package]] name = "syn" version = "1.0.109" @@ -13996,8 +14033,6 @@ dependencies = [ "dirs 6.0.0", "file", "futures-util", - "kalosm-common", - "kalosm-sound", "language", "listener-interface", "pyannote-local", @@ -15573,20 +15608,6 @@ dependencies = [ "yoke 0.7.5", ] -[[package]] -name = "ug-metal" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a02ddc17bf32f7dcaaf016b6735f7198082b82f122df7b3ca15d8ead5911ccef" -dependencies = [ - "half", - "metal 0.29.0", - "objc", - "serde", - "thiserror 1.0.69", - "ug", -] - [[package]] name = "ulid" version = "1.2.1" diff --git a/crates/audio-utils/Cargo.toml b/crates/audio-utils/Cargo.toml index 059ac0ba3..fd9712a39 100644 --- a/crates/audio-utils/Cargo.toml +++ b/crates/audio-utils/Cargo.toml @@ -7,3 +7,7 @@ edition = "2021" bytes = { workspace = true } futures-util = { workspace = true } kalosm-sound = { workspace = true, default-features = false } +thiserror = { workspace = true } + +rodio = { workspace = true } +rubato = "0.16.2" diff --git a/crates/audio-utils/src/error.rs b/crates/audio-utils/src/error.rs new file mode 100644 index 000000000..900aeb1bf --- /dev/null +++ b/crates/audio-utils/src/error.rs @@ -0,0 +1,7 @@ +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error(transparent)] + ResampleError(#[from] rubato::ResampleError), + #[error(transparent)] + ResamplerConstructionError(#[from] rubato::ResamplerConstructionError), +} diff --git a/crates/audio-utils/src/lib.rs b/crates/audio-utils/src/lib.rs index 749c2d5cc..581637c7e 100644 --- a/crates/audio-utils/src/lib.rs +++ b/crates/audio-utils/src/lib.rs @@ -2,6 +2,9 @@ use bytes::{BufMut, Bytes, BytesMut}; use futures_util::{Stream, StreamExt}; use kalosm_sound::AsyncSource; +mod error; +pub use error::*; + impl AudioFormatExt for T {} pub trait AudioFormatExt: AsyncSource { @@ -42,3 +45,54 @@ pub fn f32_to_i16_samples(samples: &[f32]) -> Vec { }) .collect() } + +pub fn resample_audio(source: S, to_rate: u32) -> Result, crate::Error> +where + S: rodio::Source + Iterator, + T: rodio::Sample, +{ + use rubato::{ + Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction, + }; + + let from_rate = source.sample_rate() as f64; + let channels = source.channels() as usize; + let to_rate_f64 = to_rate as f64; + + let samples: Vec = source.map(|sample| sample.to_f32()).collect(); + + if (from_rate - to_rate_f64).abs() < 1.0 { + return Ok(samples); + } + + let params = SincInterpolationParameters { + sinc_len: 256, + f_cutoff: 0.95, + interpolation: SincInterpolationType::Linear, + oversampling_factor: 256, + window: WindowFunction::BlackmanHarris2, + }; + + let mut resampler = + SincFixedIn::::new(to_rate_f64 / from_rate, 2.0, params, 1024, channels)?; + + let frames_per_channel = samples.len() / channels; + let mut input_channels: Vec> = vec![Vec::with_capacity(frames_per_channel); channels]; + + for (i, &sample) in samples.iter().enumerate() { + input_channels[i % channels].push(sample); + } + + let output_channels = resampler.process(&input_channels, None)?; + + let mut output = Vec::new(); + let output_frames = output_channels[0].len(); + + for frame in 0..output_frames { + for ch in 0..channels { + output.push(output_channels[ch][frame]); + } + } + + Ok(output) +} diff --git a/crates/audio/Cargo.toml b/crates/audio/Cargo.toml index 02414be36..bfe6a8e05 100644 --- a/crates/audio/Cargo.toml +++ b/crates/audio/Cargo.toml @@ -17,7 +17,7 @@ futures-util = { workspace = true } tokio = { workspace = true, features = ["rt", "macros"] } cpal = { workspace = true } -rodio = { workspace = true, features = ["vorbis"] } +rodio = { workspace = true } ebur128 = "0.1.10" kalosm-sound = { workspace = true, default-features = false } diff --git a/crates/chunker/Cargo.toml b/crates/chunker/Cargo.toml index fcf38d8f2..c0b0c8509 100644 --- a/crates/chunker/Cargo.toml +++ b/crates/chunker/Cargo.toml @@ -10,7 +10,7 @@ hypr-data = { workspace = true } [dependencies] hypr-vad = { workspace = true } kalosm-sound = { workspace = true, default-features = false } -rodio = { workspace = true, features = ["wav"] } +rodio = { workspace = true } futures-util = { workspace = true } serde = { workspace = true } diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index 377900a63..cfe853af0 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -28,7 +28,6 @@ hypr-language = { workspace = true } tauri-plugin-listener = { workspace = true } -kalosm-common = { workspace = true } tokio-tungstenite = { workspace = true } bytes = { workspace = true } @@ -52,13 +51,14 @@ tauri-plugin-task = { workspace = true } tauri-specta = { workspace = true, features = ["derive", "typescript"] } dirs = { workspace = true } -rodio = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } specta = { workspace = true } strum = { workspace = true, features = ["derive"] } thiserror = { workspace = true } +rodio = { workspace = true, features = ["symphonia", "symphonia-all"] } + axum = { workspace = true, features = ["ws", "multipart"] } tower-http = { workspace = true, features = ["cors", "trace"] } @@ -66,12 +66,3 @@ futures-util = { workspace = true } tokio = { workspace = true, features = ["rt", "macros"] } tokio-util = { workspace = true } tracing = { workspace = true } - -[target.'cfg(not(target_os = "macos"))'.dependencies] -kalosm-sound = { workspace = true, default-features = false } - -[target.'cfg(all(target_os = "macos", target_arch = "aarch64"))'.dependencies] -kalosm-sound = { workspace = true, default-features = false, features = ["metal"] } - -[target.'cfg(all(target_os = "macos", target_arch = "x86_64"))'.dependencies] -kalosm-sound = { workspace = true, default-features = false } diff --git a/plugins/local-stt/build.rs b/plugins/local-stt/build.rs index 431d00046..b184a8c17 100644 --- a/plugins/local-stt/build.rs +++ b/plugins/local-stt/build.rs @@ -10,6 +10,7 @@ const COMMANDS: &[&str] = &[ "get_current_model", "set_current_model", "list_supported_models", + "process_recorded", ]; fn main() { diff --git a/plugins/local-stt/js/bindings.gen.ts b/plugins/local-stt/js/bindings.gen.ts index c99ddb286..b1ca3d956 100644 --- a/plugins/local-stt/js/bindings.gen.ts +++ b/plugins/local-stt/js/bindings.gen.ts @@ -42,6 +42,9 @@ async stopServer() : Promise { }, async restartServer() : Promise { return await TAURI_INVOKE("plugin:local-stt|restart_server"); +}, +async processRecorded(audioPath: string) : Promise { + return await TAURI_INVOKE("plugin:local-stt|process_recorded", { audioPath }); } } @@ -61,8 +64,10 @@ recordedProcessingEvent: "plugin:local-stt:recorded-processing-event" /** user-defined types **/ export type GgmlBackend = { kind: string; name: string; description: string; total_memory_mb: number; free_memory_mb: number } -export type RecordedProcessingEvent = { type: "inactive"; current: number; total: number } +export type RecordedProcessingEvent = { type: "progress"; current: number; total: number; word: Word } +export type SpeakerIdentity = { type: "unassigned"; value: { index: number } } | { type: "assigned"; value: { id: string; label: string } } export type SupportedModel = "QuantizedTiny" | "QuantizedTinyEn" | "QuantizedBase" | "QuantizedBaseEn" | "QuantizedSmall" | "QuantizedSmallEn" | "QuantizedLargeTurbo" +export type Word = { text: string; speaker: SpeakerIdentity | null; confidence: number | null; start_ms: number | null; end_ms: number | null } /** tauri-specta globals **/ diff --git a/plugins/local-stt/permissions/autogenerated/commands/process_recorded.toml b/plugins/local-stt/permissions/autogenerated/commands/process_recorded.toml new file mode 100644 index 000000000..c5b3e018a --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/process_recorded.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-process-recorded" +description = "Enables the process_recorded command without any pre-configured scope." +commands.allow = ["process_recorded"] + +[[permission]] +identifier = "deny-process-recorded" +description = "Denies the process_recorded command without any pre-configured scope." +commands.deny = ["process_recorded"] diff --git a/plugins/local-stt/permissions/autogenerated/reference.md b/plugins/local-stt/permissions/autogenerated/reference.md index 905ce3205..c5252e2c1 100644 --- a/plugins/local-stt/permissions/autogenerated/reference.md +++ b/plugins/local-stt/permissions/autogenerated/reference.md @@ -14,6 +14,7 @@ Default permissions for the plugin - `allow-get-current-model` - `allow-set-current-model` - `allow-list-supported-models` +- `allow-process-recorded` ## Permission Table @@ -261,6 +262,32 @@ Denies the models_dir command without any pre-configured scope. +`local-stt:allow-process-recorded` + + + + +Enables the process_recorded command without any pre-configured scope. + + + + + + + +`local-stt:deny-process-recorded` + + + + +Denies the process_recorded command without any pre-configured scope. + + + + + + + `local-stt:allow-set-current-model` diff --git a/plugins/local-stt/permissions/default.toml b/plugins/local-stt/permissions/default.toml index d256ed12d..0a1f4bb03 100644 --- a/plugins/local-stt/permissions/default.toml +++ b/plugins/local-stt/permissions/default.toml @@ -11,4 +11,5 @@ permissions = [ "allow-get-current-model", "allow-set-current-model", "allow-list-supported-models", + "allow-process-recorded", ] diff --git a/plugins/local-stt/permissions/schemas/schema.json b/plugins/local-stt/permissions/schemas/schema.json index b8fe6e193..3506fc10c 100644 --- a/plugins/local-stt/permissions/schemas/schema.json +++ b/plugins/local-stt/permissions/schemas/schema.json @@ -402,6 +402,18 @@ "const": "deny-models-dir", "markdownDescription": "Denies the models_dir command without any pre-configured scope." }, + { + "description": "Enables the process_recorded command without any pre-configured scope.", + "type": "string", + "const": "allow-process-recorded", + "markdownDescription": "Enables the process_recorded command without any pre-configured scope." + }, + { + "description": "Denies the process_recorded command without any pre-configured scope.", + "type": "string", + "const": "deny-process-recorded", + "markdownDescription": "Denies the process_recorded command without any pre-configured scope." + }, { "description": "Enables the set_current_model command without any pre-configured scope.", "type": "string", @@ -439,10 +451,10 @@ "markdownDescription": "Denies the stop_server command without any pre-configured scope." }, { - "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`", + "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-process-recorded`", "type": "string", "const": "default", - "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`" + "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-process-recorded`" } ] } diff --git a/plugins/local-stt/src/commands.rs b/plugins/local-stt/src/commands.rs index 73ffe3039..826379a9a 100644 --- a/plugins/local-stt/src/commands.rs +++ b/plugins/local-stt/src/commands.rs @@ -1,6 +1,8 @@ -use crate::LocalSttPluginExt; - use tauri::ipc::Channel; +use tauri_specta::Event; + +use crate::LocalSttPluginExt; +use tauri_plugin_task::TaskPluginExt; #[tauri::command] #[specta::specta] @@ -95,3 +97,24 @@ pub async fn restart_server(app: tauri::AppHandle) -> Resu app.stop_server().await.map_err(|e| e.to_string())?; app.start_server().await.map_err(|e| e.to_string()) } + +#[tauri::command] +#[specta::specta] +pub fn process_recorded( + app: tauri::AppHandle, + audio_path: String, +) -> Result<(), String> { + let current_model = app.get_current_model().map_err(|e| e.to_string())?; + let model_path = app.models_dir().join(current_model.file_name()); + + let app_clone = app.clone(); + app.spawn_task_blocking(move |_ctx| { + let app_clone_inner = app_clone.clone(); + let _ = app_clone + .process_recorded(model_path, audio_path, move |event| { + let _ = crate::events::RecordedProcessingEvent::emit(&event, &app_clone_inner); + }) + .map_err(|e| e.to_string()); + }); + Ok(()) +} diff --git a/plugins/local-stt/src/events.rs b/plugins/local-stt/src/events.rs index 6dfd4862b..7436d0bff 100644 --- a/plugins/local-stt/src/events.rs +++ b/plugins/local-stt/src/events.rs @@ -1,7 +1,7 @@ #[macro_export] macro_rules! common_event_derives { ($item:item) => { - #[derive(serde::Serialize, Clone, specta::Type, tauri_specta::Event)] + #[derive(Debug, Clone, serde::Serialize, specta::Type, tauri_specta::Event)] $item }; } @@ -10,6 +10,6 @@ common_event_derives! { #[serde(tag = "type")] pub enum RecordedProcessingEvent { #[serde(rename = "progress")] - Progress { current: usize, total: usize }, + Progress { current: usize, total: usize, word: hypr_listener_interface::Word }, } } diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 8ed8c9a8e..3cc40ed58 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -1,14 +1,13 @@ use std::{future::Future, path::PathBuf}; -use futures_util::StreamExt; -use kalosm_sound::AsyncSource; - use tauri::{ipc::Channel, Manager, Runtime}; use tauri_plugin_store2::StorePluginExt; use hypr_file::{download_file_with_callback, DownloadProgress}; use hypr_listener_interface::Word; +use crate::events::RecordedProcessingEvent; + pub trait LocalSttPluginExt { fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore; fn models_dir(&self) -> PathBuf; @@ -24,7 +23,8 @@ pub trait LocalSttPluginExt { &self, model_path: impl AsRef, audio_path: impl AsRef, - ) -> impl Future, crate::Error>>; + progress_fn: impl FnMut(RecordedProcessingEvent) + Send + 'static, + ) -> Result, crate::Error>; fn download_model( &self, @@ -172,20 +172,28 @@ impl> LocalSttPluginExt for T { } #[tracing::instrument(skip_all)] - async fn process_recorded( + fn process_recorded( &self, model_path: impl AsRef, audio_path: impl AsRef, + mut progress_fn: impl FnMut(RecordedProcessingEvent) + Send + 'static, ) -> Result, crate::Error> { - let samples_f32: Vec = rodio::Decoder::new(std::io::BufReader::new( + use rodio::Source; + + let decoder = rodio::Decoder::new(std::io::BufReader::new( std::fs::File::open(audio_path.as_ref()).unwrap(), )) - .unwrap() - .resample(16000) - .collect::>() - .await; + .unwrap(); + + let original_sample_rate = decoder.sample_rate(); + + let resampled_samples = if original_sample_rate != 16000 { + hypr_audio_utils::resample_audio(decoder, 16000).unwrap() + } else { + decoder.convert_samples().collect() + }; - let samples_i16 = hypr_audio_utils::f32_to_i16_samples(&samples_f32); + let samples_i16 = hypr_audio_utils::f32_to_i16_samples(&resampled_samples); let mut model = hypr_whisper_local::Whisper::builder() .model_path(model_path.as_ref().to_str().unwrap()) @@ -195,6 +203,7 @@ impl> LocalSttPluginExt for T { let mut segmenter = hypr_pyannote_local::segmentation::Segmenter::new(16000).unwrap(); let segments = segmenter.process(&samples_i16, 16000).unwrap(); + let num_segments = segments.len(); let mut words = Vec::new(); @@ -209,12 +218,18 @@ impl> LocalSttPluginExt for T { let start_ms = (start_sec * 1000.0) as u64; let end_ms = (end_sec * 1000.0) as u64; - words.push(Word { + let word = Word { text: whisper_segment.text().to_string(), speaker: None, confidence: Some(whisper_segment.confidence()), start_ms: Some(start_ms), end_ms: Some(end_ms), + }; + words.push(word.clone()); + progress_fn(RecordedProcessingEvent::Progress { + current: words.len(), + total: num_segments, + word, }); } } diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index 1362df6df..0e52ba528 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -45,6 +45,7 @@ fn make_specta_builder() -> tauri_specta::Builder { commands::start_server::, commands::stop_server::, commands::restart_server::, + commands::process_recorded::, ]) .events(tauri_specta::collect_events![ events::RecordedProcessingEvent @@ -164,12 +165,13 @@ mod test { let model_path = dirs::data_dir() .unwrap() - .join("com.hyprnote.dev") + .join("com.hyprnote.dev/stt") .join("ggml-tiny.en-q8_0.bin"); let words = app - .process_recorded(model_path, hypr_data::english_1::AUDIO_PATH) - .await + .process_recorded(model_path, hypr_data::english_1::AUDIO_PATH, |event| { + println!("{:?}", event); + }) .unwrap(); println!("{:?}", words); diff --git a/plugins/task/src/commands.rs b/plugins/task/src/commands.rs index 7fb43d87c..74cae2525 100644 --- a/plugins/task/src/commands.rs +++ b/plugins/task/src/commands.rs @@ -6,7 +6,8 @@ pub async fn get_task( app: tauri::AppHandle, id: String, ) -> Result { - app.get_task(id).ok_or("not found".into()) + app.get_task(id) + .ok_or(crate::Error::TaskNotFound.to_string()) } #[tauri::command] @@ -15,5 +16,6 @@ pub async fn cancel_task( app: tauri::AppHandle, id: String, ) -> Result<(), String> { - app.cancel_task(id).map_err(|_| "not found".into()) + app.cancel_task(id) + .map_err(|_| crate::Error::TaskNotFound.to_string()) } diff --git a/plugins/task/src/ctx.rs b/plugins/task/src/ctx.rs index 5f69ef1a7..7160bda74 100644 --- a/plugins/task/src/ctx.rs +++ b/plugins/task/src/ctx.rs @@ -16,11 +16,11 @@ pub struct TaskCtx { } impl TaskCtx { - pub fn new(id: String, total: u32, store: ScopedStore) -> Self { + pub fn new(id: String, store: ScopedStore) -> Self { Self { id, current: 0, - total, + total: 1, store, cancelled: Arc::new(AtomicBool::new(false)), } @@ -60,6 +60,14 @@ impl TaskCtx { }) } + pub fn complete(&self) -> Result<(), crate::Error> { + self.update_status(TaskStatus::Completed) + } + + pub fn fail(&self, error: String) -> Result<(), crate::Error> { + self.update_status(TaskStatus::Failed { error }) + } + fn update_status(&self, status: TaskStatus) -> Result<(), crate::Error> { let id = self.id.clone(); diff --git a/plugins/task/src/error.rs b/plugins/task/src/error.rs index dcb1682ab..0e9c9ab7a 100644 --- a/plugins/task/src/error.rs +++ b/plugins/task/src/error.rs @@ -4,6 +4,8 @@ use serde::{ser::Serializer, Serialize}; pub enum Error { #[error("Store operation failed")] StoreError, + #[error("Task not found")] + TaskNotFound, } impl Serialize for Error { diff --git a/plugins/task/src/ext.rs b/plugins/task/src/ext.rs index e9fe94300..f92ccae6d 100644 --- a/plugins/task/src/ext.rs +++ b/plugins/task/src/ext.rs @@ -1,5 +1,3 @@ -use std::future::Future; - use tauri::{AppHandle, Manager, Runtime}; use tauri_plugin_store2::{ScopedStore, StorePluginExt}; @@ -8,10 +6,10 @@ use crate::{StoreKey, TaskCtx, TaskRecord, TaskState, TaskStatus}; pub trait TaskPluginExt: Manager { fn task_store(&self) -> ScopedStore; - fn spawn_task(&self, total_steps: u32, exec: F) -> String + fn spawn_task_blocking(&self, exec: F) -> String where - F: Fn(TaskCtx) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + 'static; + F: FnOnce(TaskCtx) -> Fut + Send + 'static, + Fut: Send + 'static; fn get_task(&self, id: String) -> Option; fn cancel_task(&self, id: String) -> Result<(), crate::Error>; @@ -22,13 +20,13 @@ impl> TaskPluginExt for T { self.scoped_store(crate::PLUGIN_NAME).unwrap() } - fn spawn_task(&self, total_steps: u32, exec: F) -> String + fn spawn_task_blocking(&self, exec: F) -> String where - F: Fn(TaskCtx) -> Fut + Send + Sync + 'static, - Fut: Future> + Send + 'static, + F: FnOnce(TaskCtx) -> Fut + Send + 'static, + Fut: Send + 'static, { let id = uuid::Uuid::new_v4().to_string(); - let ctx = TaskCtx::new(id.clone(), total_steps, self.task_store()); + let ctx = TaskCtx::new(id.clone(), self.task_store()); let task_state: tauri::State = self.state(); let app_handle: AppHandle = self.app_handle().clone(); @@ -40,7 +38,7 @@ impl> TaskPluginExt for T { id: id.clone(), status: TaskStatus::Running { current: 0, - total: total_steps, + total: 1, }, data: std::collections::HashMap::new(), }; @@ -50,12 +48,9 @@ impl> TaskPluginExt for T { .set(StoreKey::Tasks(id.clone()), initial_record); let task_id = id.clone(); - tauri::async_runtime::spawn(async move { - // Execute the task - the implementation is responsible for calling - // ctx.complete() or ctx.fail() to update the final status - let _ = exec(ctx).await; + tauri::async_runtime::spawn_blocking(move || { + let _ = exec(ctx); - // Remove from active tasks if let Some(state) = app_handle.try_state::() { state.remove_task(&task_id); } @@ -71,9 +66,7 @@ impl> TaskPluginExt for T { fn cancel_task(&self, id: String) -> Result<(), crate::Error> { let task_state: tauri::State = self.state(); - // Set the cancellation flag if task_state.cancel_task(&id) { - // Update the task status in store if let Some(mut record) = self.get_task(id.clone()) { record.status = TaskStatus::Cancelled; self.task_store() @@ -82,7 +75,6 @@ impl> TaskPluginExt for T { } Ok(()) } else { - // Task not found or already completed Ok(()) } } diff --git a/plugins/task/src/lib.rs b/plugins/task/src/lib.rs index 466a7109d..e168e186a 100644 --- a/plugins/task/src/lib.rs +++ b/plugins/task/src/lib.rs @@ -61,4 +61,9 @@ mod test { .build(tauri::test::mock_context(tauri::test::noop_assets())) .unwrap() } + + #[test] + fn test_task() { + let _app = create_app(tauri::test::mock_builder()); + } } diff --git a/plugins/task/src/store.rs b/plugins/task/src/store.rs index e3c8e61e6..f6b54827d 100644 --- a/plugins/task/src/store.rs +++ b/plugins/task/src/store.rs @@ -17,6 +17,19 @@ pub struct TaskRecord { pub data: std::collections::HashMap, } +impl Default for TaskRecord { + fn default() -> Self { + Self { + id: String::new(), + status: TaskStatus::Running { + current: 0, + total: 1, + }, + data: std::collections::HashMap::new(), + } + } +} + #[derive(Deserialize, specta::Type, PartialEq, Eq, Hash, strum::Display)] pub enum StoreKey { Tasks(String),