From 3d3d414ad13be864d7e4f7b5401a1da37844e3bf Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 25 Sep 2025 11:01:21 +0900 Subject: [PATCH 1/4] initial attempt to wrap cli as actor --- Cargo.lock | 2 + Cargo.toml | 3 +- plugins/listener/Cargo.toml | 2 +- plugins/local-stt/Cargo.toml | 2 + plugins/local-stt/src/ext.rs | 46 +++- plugins/local-stt/src/lib.rs | 2 +- plugins/local-stt/src/server/external.rs | 262 ++++++++++++++--------- 7 files changed, 203 insertions(+), 116 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6b07ab9b1..719fdeac6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14548,6 +14548,7 @@ dependencies = [ "audio-utils", "axum 0.8.4", "axum-extra", + "backon", "data", "dirs 6.0.0", "download-interface", @@ -14561,6 +14562,7 @@ dependencies = [ "owhisper-interface", "port-killer", "port_check", + "ractor", "reqwest 0.12.23", "rodio", "serde", diff --git a/Cargo.toml b/Cargo.toml index c3aab0ff5..590ac45a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -129,6 +129,7 @@ async-stream = "0.3.6" futures-channel = "0.3.31" futures-core = "0.3.31" futures-util = "0.3.31" +ractor = "0.15" reqwest = "0.12" reqwest-streams = "0.10.0" tokio = "1" @@ -138,7 +139,7 @@ tokio-util = "0.7.15" anyhow = "1" approx = "0.5.1" -backon = "1.4.1" +backon = "1.5.2" bytes = "1.9.0" cached = "0.55.1" clap = "4" diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 6c0c1fb12..d3f70a321 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -61,7 +61,7 @@ hound = { workspace = true } vorbis_rs = { workspace = true } futures-util = { workspace = true } -ractor = "0.15" +ractor = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tokio-stream = { workspace = true } tokio-util = { workspace = true } diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index 111ca2e85..a1e1cc093 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -74,7 +74,9 @@ axum = { workspace = true, features = ["ws"] } axum-extra = { workspace = true, features = ["query"] } tower-http = { workspace = true, features = ["cors", "trace"] } +backon = { workspace = true } futures-util = { workspace = true } +ractor = { workspace = true } reqwest = { workspace = true } tokio = { workspace = true, features = ["rt", "macros"] } tokio-util = { workspace = true } diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 4c8d6b654..53e3902e6 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -1,5 +1,6 @@ use std::{collections::HashMap, future::Future, path::PathBuf}; +use ractor::{call_t, Actor}; use tauri::{ipc::Channel, Manager, Runtime}; use tauri_plugin_shell::ShellExt; use tauri_plugin_store2::StorePluginExt; @@ -11,7 +12,10 @@ use tokio_util::sync::CancellationToken; use crate::{ model::SupportedSttModel, - server::{external, internal, ServerHealth, ServerType}, + server::{ + external::{self, ExternalSTTMessage}, + internal, ServerHealth, ServerType, + }, Connection, Provider, StoreKey, }; @@ -150,7 +154,16 @@ impl> LocalSttPluginExt for T { let existing_api_base = { let state = self.state::(); let guard = state.lock().await; - guard.external_server.as_ref().map(|s| s.base_url.clone()) + + match &guard.external_server { + Some(server) => { + let (base_url, _) = + call_t!(server, ExternalSTTMessage::GetHealth, 10 * 1000) + .unwrap(); + Some(base_url) + } + None => None, + } }; let am_key = { @@ -349,10 +362,21 @@ impl> LocalSttPluginExt for T { .args(["serve"]) }; - let server = external::run_server(cmd, am_key).await?; - tokio::time::sleep(std::time::Duration::from_millis(250)).await; - let _ = server.init(am_model, data_dir).await; - let api_base = server.base_url.clone(); + let (server, _) = Actor::spawn( + Some("external_stt".to_string()), + external::ExternalSTTActor, + external::ExternalSTTArgs { + cmd, + api_key: am_key, + model: am_model, + models_dir: data_dir, + }, + ) + .await + .unwrap(); + + let (base_url, _) = + call_t!(server, ExternalSTTMessage::GetHealth, 10 * 1000).unwrap(); { let state = self.state::(); @@ -360,7 +384,7 @@ impl> LocalSttPluginExt for T { s.external_server = Some(server); } - Ok(api_base) + Ok(base_url) } } } @@ -379,9 +403,8 @@ impl> LocalSttPluginExt for T { let mut stopped = false; match server_type { Some(ServerType::External) => { - hypr_host::kill_processes_by_matcher(hypr_host::ProcessMatcher::Sidecar); - - if let Some(_) = s.external_server.take() { + if let Some(actor) = s.external_server.take() { + actor.stop(None); stopped = true; } } @@ -417,7 +440,8 @@ impl> LocalSttPluginExt for T { }; let external_health = if let Some(server) = &guard.external_server { - server.health().await + let (_, status) = call_t!(server, ExternalSTTMessage::GetHealth, 10 * 1000).unwrap(); + status } else { ServerHealth::Unreachable }; diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index 8883af529..524d3482a 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -24,7 +24,7 @@ pub type SharedState = std::sync::Arc>; pub struct State { pub am_api_key: Option, pub internal_server: Option, - pub external_server: Option, + pub external_server: Option>, pub download_task: HashMap, CancellationToken)>, } diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index 1b0fa9e52..4537d5b03 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -1,125 +1,183 @@ +use std::path::PathBuf; +use tauri_plugin_shell::process::{Command, CommandChild}; + use super::ServerHealth; +use backon::{ConstantBuilder, Retryable}; +use ractor::{Actor, ActorProcessingErr, ActorRef, RpcReplyPort}; -pub struct ServerHandle { - pub base_url: String, - api_key: Option, - shutdown: tokio::sync::watch::Sender<()>, - client: hypr_am::Client, +#[derive(Debug)] +pub enum ExternalSTTMessage { + GetHealth(RpcReplyPort<(String, ServerHealth)>), } -impl Drop for ServerHandle { - fn drop(&mut self) { - tracing::info!("stopping"); - let _ = self.shutdown.send(()); - } +pub struct ExternalSTTArgs { + pub cmd: Command, + pub api_key: String, + pub model: hypr_am::AmModel, + pub models_dir: PathBuf, } -impl ServerHandle { - pub async fn health(&self) -> ServerHealth { - let res = self.client.status().await; - if res.is_err() { - tracing::error!("{:?}", res); - return ServerHealth::Unreachable; - } - - let res = res.unwrap(); - - if res.model_state == hypr_am::ModelState::Loading { - return ServerHealth::Loading; - } +pub struct ExternalSTTState { + base_url: String, + api_key: Option, + model: hypr_am::AmModel, + models_dir: PathBuf, + client: hypr_am::Client, + process_handle: Option, + task_handle: Option>, +} - if res.model_state == hypr_am::ModelState::Loaded { - return ServerHealth::Ready; - } +pub struct ExternalSTTActor; - ServerHealth::Unreachable - } +impl Actor for ExternalSTTActor { + type Msg = ExternalSTTMessage; + type State = ExternalSTTState; + type Arguments = ExternalSTTArgs; - pub async fn init( + async fn pre_start( &self, - model: hypr_am::AmModel, - models_dir: impl AsRef, - ) -> Result { - let r = self - .client - .init( - hypr_am::InitRequest::new(self.api_key.clone().unwrap()) - .with_model(model, models_dir), - ) - .await?; - - Ok(r) - } -} + _myself: ActorRef, + args: Self::Arguments, + ) -> Result { + let port = port_check::free_local_port().unwrap(); + let (mut rx, child) = args.cmd.args(["--port", &port.to_string()]).spawn()?; + let base_url = format!("http://localhost:{}", port); + let client = hypr_am::Client::new(&base_url); -pub async fn run_server( - cmd: tauri_plugin_shell::process::Command, - am_key: String, -) -> Result { - let port = port_check::free_local_port().unwrap(); - let (mut rx, child) = cmd.args(["--port", &port.to_string()]).spawn()?; - let base_url = format!("http://localhost:{}", port); - let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); - let client = hypr_am::Client::new(&base_url); - - tokio::spawn(async move { - let mut process_ended = false; - - loop { - tokio::select! { - _ = shutdown_rx.changed() => { - tracing::info!("shutdown_signal_received"); - break; - } - event = rx.recv() => { - match event { - Some(tauri_plugin_shell::process::CommandEvent::Stdout(bytes)) => { - if let Ok(text) = String::from_utf8(bytes) { - let text = text.trim(); - if !text.is_empty() && !text.contains("[TranscriptionHandler]") && !text.contains("[WebSocket]") && !text.contains("Sent interim") { - tracing::info!("{}", text); - } + let task_handle = tokio::spawn(async move { + loop { + match rx.recv().await { + Some(tauri_plugin_shell::process::CommandEvent::Stdout(bytes)) => { + if let Ok(text) = String::from_utf8(bytes) { + let text = text.trim(); + if !text.is_empty() + && !text.contains("[TranscriptionHandler]") + && !text.contains("[WebSocket]") + && !text.contains("Sent interim") + { + tracing::info!("{}", text); } } - Some(tauri_plugin_shell::process::CommandEvent::Stderr(bytes)) => { - if let Ok(text) = String::from_utf8(bytes) { - let text = text.trim(); - if !text.is_empty() && !text.contains("[TranscriptionHandler]") && !text.contains("[WebSocket]") && !text.contains("Sent interim") { - tracing::info!("{}", text); - } + } + Some(tauri_plugin_shell::process::CommandEvent::Stderr(bytes)) => { + if let Ok(text) = String::from_utf8(bytes) { + let text = text.trim(); + if !text.is_empty() + && !text.contains("[TranscriptionHandler]") + && !text.contains("[WebSocket]") + && !text.contains("Sent interim") + { + tracing::info!("{}", text); } } - Some(tauri_plugin_shell::process::CommandEvent::Terminated(payload)) => { - tracing::error!("terminated: {:?}", payload); - process_ended = true; - break; - } - Some(tauri_plugin_shell::process::CommandEvent::Error(error)) => { - tracing::error!("{}", error); - break; - } - None => { - tracing::warn!("closed"); - process_ended = true; - break; - } - _ => {} } + Some(tauri_plugin_shell::process::CommandEvent::Terminated(payload)) => { + tracing::error!("terminated: {:?}", payload); + break; + } + Some(tauri_plugin_shell::process::CommandEvent::Error(error)) => { + tracing::error!("{}", error); + break; + } + None => { + tracing::warn!("closed"); + break; + } + _ => {} } } + }); + + Ok(ExternalSTTState { + base_url, + api_key: Some(args.api_key), + model: args.model, + models_dir: args.models_dir, + client, + process_handle: Some(child), + task_handle: Some(task_handle), + }) + } + async fn post_start( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + let api_key = state.api_key.clone().unwrap(); + let model = state.model.clone(); + let models_dir = state.models_dir.clone(); + + let res = (|| async { + state + .client + .init( + hypr_am::InitRequest::new(api_key.clone()) + .with_model(model.clone(), &models_dir), + ) + .await + }) + .retry( + ConstantBuilder::default() + .with_max_times(20) + .with_delay(std::time::Duration::from_millis(500)), + ) + .when(|e| { + tracing::error!("external_stt_init_failed: {:?}", e); + true + }) + .sleep(tokio::time::sleep) + .await?; + + tracing::info!(res = ?res); + Ok(()) + } + + async fn post_stop( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + if let Some(process) = state.process_handle.take() { + if let Err(e) = process.kill() { + tracing::error!("failed_to_kill_process: {:?}", e); + } } - if !process_ended { - if let Err(e) = child.kill() { - tracing::error!("{:?}", e); + if let Some(task) = state.task_handle.take() { + task.abort(); + } + + hypr_host::kill_processes_by_matcher(hypr_host::ProcessMatcher::Sidecar); + + Ok(()) + } + + async fn handle( + &self, + _myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + ExternalSTTMessage::GetHealth(reply_port) => { + let status = match state.client.status().await { + Ok(r) => match r.model_state { + hypr_am::ModelState::Loading => ServerHealth::Loading, + hypr_am::ModelState::Loaded => ServerHealth::Ready, + _ => ServerHealth::Unreachable, + }, + Err(e) => { + tracing::error!("{:?}", e); + ServerHealth::Unreachable + } + }; + + if let Err(e) = reply_port.send((state.base_url.clone(), status)) { + tracing::error!("{:?}", e); + } + + Ok(()) } } - }); - - Ok(ServerHandle { - api_key: Some(am_key), - base_url, - shutdown: shutdown_tx, - client, - }) + } } From 490334b850eed10bb235ce9f6a2822ca5cd1a2e4 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 25 Sep 2025 13:46:14 +0900 Subject: [PATCH 2/4] kind of done --- crates/transcribe-whisper-local/src/lib.rs | 9 +- .../src/service/streaming.rs | 2 +- plugins/listener/src/actors/listener.rs | 28 +- plugins/local-stt/src/ext.rs | 145 ++++++----- plugins/local-stt/src/lib.rs | 3 +- plugins/local-stt/src/server/external.rs | 26 +- plugins/local-stt/src/server/internal.rs | 241 +++++++----------- plugins/local-stt/src/server/mod.rs | 2 + 8 files changed, 224 insertions(+), 232 deletions(-) diff --git a/crates/transcribe-whisper-local/src/lib.rs b/crates/transcribe-whisper-local/src/lib.rs index a8c2ef352..05c2bee37 100644 --- a/crates/transcribe-whisper-local/src/lib.rs +++ b/crates/transcribe-whisper-local/src/lib.rs @@ -8,6 +8,7 @@ pub use service::*; // cargo test -p transcribe-whisper-local test_service -- --nocapture mod tests { use super::*; + use axum::{error_handling::HandleError, http::StatusCode}; use futures_util::StreamExt; use hypr_audio_utils::AudioFormatExt; @@ -18,7 +19,10 @@ mod tests { .join("com.hyprnote.dev") .join("stt/ggml-small-q8_0.bin"); - let service = TranscribeService::builder().model_path(model_path).build(); + let service = HandleError::new( + TranscribeService::builder().model_path(model_path).build(), + move |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ); let app = axum::Router::new().route_service("/v1/listen", service); @@ -43,8 +47,7 @@ mod tests { .to_i16_le_chunks(16000, 512); let input = audio.map(|chunk| owhisper_interface::MixedMessage::Audio(chunk)); - let stream = client.from_realtime_audio(input).await.unwrap(); - futures_util::pin_mut!(stream); + let _ = client.from_realtime_audio(input).await.unwrap(); server_handle.abort(); Ok(()) diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs index cb8b061f9..796701c17 100644 --- a/crates/transcribe-whisper-local/src/service/streaming.rs +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -62,7 +62,7 @@ where B: Send + 'static, { type Response = Response; - type Error = std::convert::Infallible; + type Error = String; type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 52b76795e..7f44036b0 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use futures_util::StreamExt; use owhisper_interface::{ControlMessage, MixedMessage, Word2}; -use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef}; +use ractor::{pg, Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; use tauri_specta::Event; use crate::{manager::TranscriptManager, SessionEvent}; @@ -47,10 +47,20 @@ impl Actor for ListenerActor { myself: ActorRef, args: Self::Arguments, ) -> Result { + pg::monitor(tauri_plugin_local_stt::GROUP.into(), myself.get_cell()); let (tx, rx_task) = spawn_rx_task(args, myself).await.unwrap(); Ok(ListenerState { tx, rx_task }) } + async fn post_stop( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + state.rx_task.abort(); + Ok(()) + } + async fn handle( &self, _myself: ActorRef, @@ -65,12 +75,20 @@ impl Actor for ListenerActor { Ok(()) } - async fn post_stop( + async fn handle_supervisor_evt( &self, - _myself: ActorRef, - state: &mut Self::State, + myself: ActorRef, + message: SupervisionEvent, + _state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { - state.rx_task.abort(); + tracing::info!("supervisor_event: {:?}", message); + + match message { + SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} + SupervisionEvent::ActorFailed(_, _) | SupervisionEvent::ActorTerminated(_, _, _) => { + myself.stop(None) + } + } Ok(()) } } diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 53e3902e6..75ed4308a 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -1,6 +1,6 @@ use std::{collections::HashMap, future::Future, path::PathBuf}; -use ractor::{call_t, Actor}; +use ractor::{call_t, registry, Actor, ActorRef}; use tauri::{ipc::Channel, Manager, Runtime}; use tauri_plugin_shell::ShellExt; use tauri_plugin_store2::StorePluginExt; @@ -152,17 +152,19 @@ impl> LocalSttPluginExt for T { } SupportedSttModel::Am(_) => { let existing_api_base = { - let state = self.state::(); - let guard = state.lock().await; - - match &guard.external_server { - Some(server) => { - let (base_url, _) = - call_t!(server, ExternalSTTMessage::GetHealth, 10 * 1000) - .unwrap(); - Some(base_url) - } - None => None, + if let Some(cell) = + registry::where_is(external::ExternalSTTActor::name()) + { + let actor: ActorRef = cell.into(); + let (base_url, _) = call_t!( + actor, + external::ExternalSTTMessage::GetHealth, + 10 * 1000 + ) + .unwrap(); + Some(base_url) + } else { + None } }; @@ -190,11 +192,20 @@ impl> LocalSttPluginExt for T { Ok(conn) } SupportedSttModel::Whisper(_) => { - let existing_api_base = { - let state = self.state::(); - let guard = state.lock().await; - guard.internal_server.as_ref().map(|s| s.base_url.clone()) - }; + let existing_api_base = + match registry::where_is(internal::InternalSTTActor::name()) { + Some(cell) => { + let actor: ActorRef = cell.into(); + let (base_url, _) = call_t!( + actor, + internal::InternalSTTMessage::GetHealth, + 10 * 1000 + ) + .unwrap(); + Some(base_url) + } + None => None, + }; let conn = match existing_api_base { Some(api_base) => Connection { @@ -272,13 +283,7 @@ impl> LocalSttPluginExt for T { return Err(crate::Error::ModelNotDownloaded); } - if self - .state::() - .lock() - .await - .internal_server - .is_some() - { + if registry::where_is(internal::InternalSTTActor::name()).is_some() { return Err(crate::Error::ServerAlreadyRunning); } @@ -289,31 +294,24 @@ impl> LocalSttPluginExt for T { } }; - let server_state = internal::ServerState::builder() - .model_cache_dir(cache_dir) - .model_type(whisper_model) - .build(); - - let server = internal::run_server(server_state).await?; - let base_url = server.base_url.clone(); - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let (server, _) = Actor::spawn( + Some(internal::InternalSTTActor::name()), + internal::InternalSTTActor, + internal::InternalSTTArgs { + model_cache_dir: cache_dir, + model_type: whisper_model, + }, + ) + .await + .unwrap(); - { - let state = self.state::(); - let mut s = state.lock().await; - s.internal_server = Some(server); - } + let (base_url, _) = + call_t!(server, internal::InternalSTTMessage::GetHealth, 10 * 1000).unwrap(); Ok(base_url) } ServerType::External => { - if self - .state::() - .lock() - .await - .external_server - .is_some() - { + if registry::where_is(external::ExternalSTTActor::name()).is_some() { return Err(crate::Error::ServerAlreadyRunning); } @@ -363,7 +361,7 @@ impl> LocalSttPluginExt for T { }; let (server, _) = Actor::spawn( - Some("external_stt".to_string()), + Some(external::ExternalSTTActor::name()), external::ExternalSTTActor, external::ExternalSTTArgs { cmd, @@ -378,12 +376,6 @@ impl> LocalSttPluginExt for T { let (base_url, _) = call_t!(server, ExternalSTTMessage::GetHealth, 10 * 1000).unwrap(); - { - let state = self.state::(); - let mut s = state.lock().await; - s.external_server = Some(server); - } - Ok(base_url) } } @@ -397,28 +389,32 @@ impl> LocalSttPluginExt for T { return Ok(false); } - let state = self.state::(); - let mut s = state.lock().await; - let mut stopped = false; match server_type { Some(ServerType::External) => { - if let Some(actor) = s.external_server.take() { + if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { + let actor: ActorRef = cell.into(); actor.stop(None); stopped = true; } } Some(ServerType::Internal) => { - if let Some(_) = s.internal_server.take() { + if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + actor.stop(None); stopped = true; } } Some(ServerType::Custom) => {} None => { - if let Some(_) = s.external_server.take() { + if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + actor.stop(None); stopped = true; } - if let Some(_) = s.internal_server.take() { + if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + actor.stop(None); stopped = true; } } @@ -429,22 +425,25 @@ impl> LocalSttPluginExt for T { #[tracing::instrument(skip_all)] async fn get_servers(&self) -> Result, crate::Error> { - let state = self.state::(); - let guard = state.lock().await; - - let internal_health = if let Some(server) = &guard.internal_server { - let status = server.health().await; - status - } else { - ServerHealth::Unreachable - }; + let internal_health = + if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + let (_, status) = + call_t!(actor, internal::InternalSTTMessage::GetHealth, 10 * 1000).unwrap(); + status + } else { + ServerHealth::Unreachable + }; - let external_health = if let Some(server) = &guard.external_server { - let (_, status) = call_t!(server, ExternalSTTMessage::GetHealth, 10 * 1000).unwrap(); - status - } else { - ServerHealth::Unreachable - }; + let external_health = + if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + let (_, status) = + call_t!(actor, external::ExternalSTTMessage::GetHealth, 10 * 1000).unwrap(); + status + } else { + ServerHealth::Unreachable + }; let custom_health = { let provider = self.get_provider()?; diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index 524d3482a..e09d1639d 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -15,6 +15,7 @@ pub use error::*; use events::*; pub use ext::*; pub use model::*; +pub use server::GROUP; pub use store::*; pub use types::*; @@ -23,8 +24,6 @@ pub type SharedState = std::sync::Arc>; #[derive(Default)] pub struct State { pub am_api_key: Option, - pub internal_server: Option, - pub external_server: Option>, pub download_task: HashMap, CancellationToken)>, } diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index 4537d5b03..2663c8cf4 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -3,11 +3,11 @@ use tauri_plugin_shell::process::{Command, CommandChild}; use super::ServerHealth; use backon::{ConstantBuilder, Retryable}; -use ractor::{Actor, ActorProcessingErr, ActorRef, RpcReplyPort}; +use ractor::{pg, Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; -#[derive(Debug)] pub enum ExternalSTTMessage { GetHealth(RpcReplyPort<(String, ServerHealth)>), + ProcessTerminated(String), } pub struct ExternalSTTArgs { @@ -29,6 +29,12 @@ pub struct ExternalSTTState { pub struct ExternalSTTActor; +impl ExternalSTTActor { + pub fn name() -> ActorName { + "external_stt".into() + } +} + impl Actor for ExternalSTTActor { type Msg = ExternalSTTMessage; type State = ExternalSTTState; @@ -36,9 +42,11 @@ impl Actor for ExternalSTTActor { async fn pre_start( &self, - _myself: ActorRef, + myself: ActorRef, args: Self::Arguments, ) -> Result { + pg::join(super::GROUP.into(), vec![myself.get_cell()]); + let port = port_check::free_local_port().unwrap(); let (mut rx, child) = args.cmd.args(["--port", &port.to_string()]).spawn()?; let base_url = format!("http://localhost:{}", port); @@ -72,11 +80,14 @@ impl Actor for ExternalSTTActor { } } Some(tauri_plugin_shell::process::CommandEvent::Terminated(payload)) => { - tracing::error!("terminated: {:?}", payload); + let e = format!("{:?}", payload); + tracing::error!("{}", e); + let _ = myself.send_message(ExternalSTTMessage::ProcessTerminated(e)); break; } Some(tauri_plugin_shell::process::CommandEvent::Error(error)) => { tracing::error!("{}", error); + let _ = myself.send_message(ExternalSTTMessage::ProcessTerminated(error)); break; } None => { @@ -134,9 +145,11 @@ impl Actor for ExternalSTTActor { async fn post_stop( &self, - _myself: ActorRef, + myself: ActorRef, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { + pg::leave(super::GROUP.into(), vec![myself.get_cell()]); + if let Some(process) = state.process_handle.take() { if let Err(e) = process.kill() { tracing::error!("failed_to_kill_process: {:?}", e); @@ -159,6 +172,7 @@ impl Actor for ExternalSTTActor { state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { match message { + ExternalSTTMessage::ProcessTerminated(e) => Err(e.into()), ExternalSTTMessage::GetHealth(reply_port) => { let status = match state.client.status().await { Ok(r) => match r.model_state { @@ -173,7 +187,7 @@ impl Actor for ExternalSTTActor { }; if let Err(e) = reply_port.send((state.base_url.clone(), status)) { - tracing::error!("{:?}", e); + return Err(e.into()); } Ok(()) diff --git a/plugins/local-stt/src/server/internal.rs b/plugins/local-stt/src/server/internal.rs index d85823a89..b9bee2dea 100644 --- a/plugins/local-stt/src/server/internal.rs +++ b/plugins/local-stt/src/server/internal.rs @@ -3,166 +3,123 @@ use std::{ path::PathBuf, }; -use axum::{http::StatusCode, response::IntoResponse, routing::get, Router}; +use axum::{error_handling::HandleError, Router}; +use ractor::{pg, Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; +use reqwest::StatusCode; use tower_http::cors::{self, CorsLayer}; use super::ServerHealth; use hypr_whisper_local_model::WhisperModel; -#[derive(Default)] -pub struct ServerStateBuilder { - pub model_type: Option, - pub model_cache_dir: Option, +pub enum InternalSTTMessage { + GetHealth(RpcReplyPort<(String, ServerHealth)>), + ServerError(String), } -impl ServerStateBuilder { - pub fn model_cache_dir(mut self, model_cache_dir: PathBuf) -> Self { - self.model_cache_dir = Some(model_cache_dir); - self - } - - pub fn model_type(mut self, model_type: WhisperModel) -> Self { - self.model_type = Some(model_type); - self - } - - pub fn build(self) -> ServerState { - ServerState { - model_type: self.model_type.unwrap(), - model_cache_dir: self.model_cache_dir.unwrap(), - } - } +pub struct InternalSTTArgs { + pub model_type: WhisperModel, + pub model_cache_dir: PathBuf, } -#[derive(Clone)] -pub struct ServerState { - model_type: WhisperModel, - model_cache_dir: PathBuf, +pub struct InternalSTTState { + base_url: String, + shutdown: tokio::sync::watch::Sender<()>, + server_task: tokio::task::JoinHandle<()>, } -impl ServerState { - pub fn builder() -> ServerStateBuilder { - ServerStateBuilder::default() +pub struct InternalSTTActor; + +impl InternalSTTActor { + pub fn name() -> ActorName { + "internal_stt".into() } } -#[derive(Clone)] -pub struct ServerHandle { - pub base_url: String, - pub api_key: Option, - shutdown: tokio::sync::watch::Sender<()>, -} +impl Actor for InternalSTTActor { + type Msg = InternalSTTMessage; + type State = InternalSTTState; + type Arguments = InternalSTTArgs; + + async fn pre_start( + &self, + myself: ActorRef, + args: Self::Arguments, + ) -> Result { + pg::join(super::GROUP.into(), vec![myself.get_cell()]); + + let model_path = args.model_cache_dir.join(args.model_type.file_name()); + + let whisper_service = HandleError::new( + hypr_transcribe_whisper_local::TranscribeService::builder() + .model_path(model_path) + .build(), + move |err: String| async move { + let _ = myself.send_message(InternalSTTMessage::ServerError(err.clone())); + (StatusCode::INTERNAL_SERVER_ERROR, err) + }, + ); + + let router = Router::new() + .route_service("/v1/listen", whisper_service) + .layer( + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods(cors::Any) + .allow_headers(cors::Any), + ); + + let listener = + tokio::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await?; + + let server_addr = listener.local_addr()?; + let base_url = format!("http://{}", server_addr); + + let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); + + let server_task = tokio::spawn(async move { + axum::serve(listener, router) + .with_graceful_shutdown(async move { + shutdown_rx.changed().await.ok(); + }) + .await + .unwrap(); + }); + + Ok(InternalSTTState { + base_url, + shutdown: shutdown_tx, + server_task, + }) + } -impl Drop for ServerHandle { - fn drop(&mut self) { - tracing::info!("stopping: {}", self.base_url); - let _ = self.shutdown.send(()); + async fn post_stop( + &self, + myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + pg::leave(super::GROUP.into(), vec![myself.get_cell()]); + let _ = state.shutdown.send(()); + state.server_task.abort(); + Ok(()) } -} -impl ServerHandle { - pub async fn health(&self) -> ServerHealth { - let client = reqwest::Client::new(); - let url = format!("{}/health", self.base_url); - - match client - .get(&url) - .timeout(std::time::Duration::from_secs(2)) - .send() - .await - { - Ok(res) if res.status().is_success() => ServerHealth::Ready, - Ok(_res) => ServerHealth::Unreachable, - Err(e) => { - tracing::error!("{:?}", e); - ServerHealth::Unreachable + async fn handle( + &self, + _myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + InternalSTTMessage::GetHealth(reply_port) => { + let status = ServerHealth::Ready; + + if let Err(e) = reply_port.send((state.base_url.clone(), status)) { + return Err(e.into()); + } + + Ok(()) } + InternalSTTMessage::ServerError(e) => Err(e.into()), } } } - -pub async fn run_server(state: ServerState) -> Result { - tracing::info!("starting"); - let router = make_service_router(state); - - let listener = - tokio::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await?; - - let server_addr = listener.local_addr()?; - let base_url = format!("http://{}", server_addr); - - let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); - - let server_handle = ServerHandle { - base_url, - api_key: None, - shutdown: shutdown_tx, - }; - - tokio::spawn(async move { - axum::serve(listener, router) - .with_graceful_shutdown(async move { - shutdown_rx.changed().await.ok(); - }) - .await - .unwrap(); - }); - - tracing::info!("local_stt_server_started {}", server_addr); - Ok(server_handle) -} - -fn make_service_router(state: ServerState) -> Router { - let model_path = state.model_cache_dir.join(state.model_type.file_name()); - - let whisper_service = hypr_transcribe_whisper_local::TranscribeService::builder() - .model_path(model_path) - .build(); - - Router::new() - .route("/health", get(health)) - .route_service("/v1/listen", whisper_service) - .layer( - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods(cors::Any) - .allow_headers(cors::Any), - ) -} - -async fn health() -> impl IntoResponse { - StatusCode::OK -} - -#[cfg(test)] -mod tests { - use super::*; - - use axum::body::Body; - use axum::http::{Request, StatusCode}; - use tower::ServiceExt; - - use hypr_whisper_local_model::WhisperModel; - - #[tokio::test] - async fn test_health_endpoint() { - let state = ServerStateBuilder::default() - .model_cache_dir(dirs::data_dir().unwrap().join("com.hyprnote.dev/stt")) - .model_type(WhisperModel::QuantizedTinyEn) - .build(); - - let app = make_service_router(state); - - let response = app - .oneshot( - Request::builder() - .uri("/health") - .body(Body::empty()) - .unwrap(), - ) - .await - .unwrap(); - - assert_eq!(response.status(), StatusCode::OK); - } -} diff --git a/plugins/local-stt/src/server/mod.rs b/plugins/local-stt/src/server/mod.rs index 2383cbb37..8e3bd1faa 100644 --- a/plugins/local-stt/src/server/mod.rs +++ b/plugins/local-stt/src/server/mod.rs @@ -1,6 +1,8 @@ pub mod external; pub mod internal; +pub const GROUP: &str = "stt"; + #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, specta::Type, )] From ea94baa16bfd012fbe693ca83e70f6cfdb4aef8d Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 25 Sep 2025 15:25:42 +0900 Subject: [PATCH 3/4] wip --- plugins/listener/src/actors/listener.rs | 59 +++++++--- plugins/local-stt/src/ext.rs | 132 +++++++++++------------ plugins/local-stt/src/lib.rs | 2 +- plugins/local-stt/src/server/external.rs | 24 +---- plugins/local-stt/src/server/internal.rs | 7 +- plugins/local-stt/src/server/mod.rs | 2 - 6 files changed, 113 insertions(+), 113 deletions(-) diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 7f44036b0..bfb7b7975 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use futures_util::StreamExt; use owhisper_interface::{ControlMessage, MixedMessage, Word2}; -use ractor::{pg, Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; use tauri_specta::Event; use crate::{manager::TranscriptManager, SessionEvent}; @@ -27,6 +27,7 @@ pub struct ListenerArgs { pub struct ListenerState { tx: tokio::sync::mpsc::Sender>, rx_task: tokio::task::JoinHandle<()>, + shutdown_tx: Option>, } pub struct ListenerActor; @@ -47,9 +48,19 @@ impl Actor for ListenerActor { myself: ActorRef, args: Self::Arguments, ) -> Result { - pg::monitor(tauri_plugin_local_stt::GROUP.into(), myself.get_cell()); - let (tx, rx_task) = spawn_rx_task(args, myself).await.unwrap(); - Ok(ListenerState { tx, rx_task }) + { + use tauri_plugin_local_stt::LocalSttPluginExt; + let _ = args.app.start_server(None).await; + } + + let (tx, rx_task, shutdown_tx) = spawn_rx_task(args, myself).await.unwrap(); + let state = ListenerState { + tx, + rx_task, + shutdown_tx: Some(shutdown_tx), + }; + + Ok(state) } async fn post_stop( @@ -57,6 +68,9 @@ impl Actor for ListenerActor { _myself: ActorRef, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { + if let Some(shutdown_tx) = state.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } state.rx_task.abort(); Ok(()) } @@ -85,8 +99,9 @@ impl Actor for ListenerActor { match message { SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} - SupervisionEvent::ActorFailed(_, _) | SupervisionEvent::ActorTerminated(_, _, _) => { - myself.stop(None) + SupervisionEvent::ActorTerminated(_, _, _) => {} + SupervisionEvent::ActorFailed(_cell, _) => { + myself.stop(None); } } Ok(()) @@ -100,10 +115,12 @@ async fn spawn_rx_task( ( tokio::sync::mpsc::Sender>, tokio::task::JoinHandle<()>, + tokio::sync::oneshot::Sender<()>, ), ActorProcessingErr, > { let (tx, rx) = tokio::sync::mpsc::channel::>(32); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); let app = args.app.clone(); let session_id = args.session_id.clone(); @@ -127,7 +144,7 @@ async fn spawn_rx_task( let rx_task = tokio::spawn(async move { let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); - let (listen_stream, _handle) = match client.from_realtime_audio(outbound).await { + let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { Ok(res) => res, Err(e) => { tracing::error!("listen_ws_connect_failed: {:?}", e); @@ -140,8 +157,14 @@ async fn spawn_rx_task( let mut manager = TranscriptManager::with_unix_timestamp(session_start_ts_ms); loop { - match tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()).await { - Ok(Some(response)) => { + tokio::select! { + _ = &mut shutdown_rx => { + handle.finalize_with_text(serde_json::json!({"type": "Finalize"}).to_string().into()).await; + break; + } + result = tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()) => { + match result { + Ok(Some(response)) => { let diff = manager.append(response.clone()); let partial_words_by_channel: HashMap> = diff @@ -197,13 +220,15 @@ async fn spawn_rx_task( .emit(&app) .unwrap(); } - Ok(None) => { - tracing::info!("listen_stream_ended"); - break; - } - Err(_) => { - tracing::info!("listen_stream_timeout"); - break; + Ok(None) => { + tracing::info!("listen_stream_ended"); + break; + } + Err(_) => { + tracing::info!("listen_stream_timeout"); + break; + } + } } } } @@ -211,7 +236,7 @@ async fn spawn_rx_task( myself.stop(None); }); - Ok((tx, rx_task)) + Ok((tx, rx_task, shutdown_tx)) } async fn update_session( diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 75ed4308a..5bd885b77 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -1,6 +1,8 @@ use std::{collections::HashMap, future::Future, path::PathBuf}; use ractor::{call_t, registry, Actor, ActorRef}; +use tokio_util::sync::CancellationToken; + use tauri::{ipc::Channel, Manager, Runtime}; use tauri_plugin_shell::ShellExt; use tauri_plugin_store2::StorePluginExt; @@ -8,14 +10,10 @@ use tauri_plugin_store2::StorePluginExt; use hypr_download_interface::DownloadProgress; use hypr_file::download_file_parallel_cancellable; use hypr_whisper_local_model::WhisperModel; -use tokio_util::sync::CancellationToken; use crate::{ model::SupportedSttModel, - server::{ - external::{self, ExternalSTTMessage}, - internal, ServerHealth, ServerType, - }, + server::{external, internal, ServerHealth, ServerType}, Connection, Provider, StoreKey, }; @@ -151,22 +149,7 @@ impl> LocalSttPluginExt for T { }) } SupportedSttModel::Am(_) => { - let existing_api_base = { - if let Some(cell) = - registry::where_is(external::ExternalSTTActor::name()) - { - let actor: ActorRef = cell.into(); - let (base_url, _) = call_t!( - actor, - external::ExternalSTTMessage::GetHealth, - 10 * 1000 - ) - .unwrap(); - Some(base_url) - } else { - None - } - }; + let existing_api_base = external_health().await.map(|r| r.0); let am_key = { let state = self.state::(); @@ -192,20 +175,7 @@ impl> LocalSttPluginExt for T { Ok(conn) } SupportedSttModel::Whisper(_) => { - let existing_api_base = - match registry::where_is(internal::InternalSTTActor::name()) { - Some(cell) => { - let actor: ActorRef = cell.into(); - let (base_url, _) = call_t!( - actor, - internal::InternalSTTMessage::GetHealth, - 10 * 1000 - ) - .unwrap(); - Some(base_url) - } - None => None, - }; + let existing_api_base = internal_health().await.map(|r| r.0); let conn = match existing_api_base { Some(api_base) => Connection { @@ -294,7 +264,7 @@ impl> LocalSttPluginExt for T { } }; - let (server, _) = Actor::spawn( + let (_server, _) = Actor::spawn( Some(internal::InternalSTTActor::name()), internal::InternalSTTActor, internal::InternalSTTArgs { @@ -305,9 +275,7 @@ impl> LocalSttPluginExt for T { .await .unwrap(); - let (base_url, _) = - call_t!(server, internal::InternalSTTMessage::GetHealth, 10 * 1000).unwrap(); - + let base_url = internal_health().await.map(|r| r.0).unwrap(); Ok(base_url) } ServerType::External => { @@ -360,7 +328,7 @@ impl> LocalSttPluginExt for T { .args(["serve"]) }; - let (server, _) = Actor::spawn( + let (_server, _) = Actor::spawn( Some(external::ExternalSTTActor::name()), external::ExternalSTTActor, external::ExternalSTTArgs { @@ -373,9 +341,7 @@ impl> LocalSttPluginExt for T { .await .unwrap(); - let (base_url, _) = - call_t!(server, ExternalSTTMessage::GetHealth, 10 * 1000).unwrap(); - + let base_url = external_health().await.map(|v| v.0).unwrap(); Ok(base_url) } } @@ -394,28 +360,40 @@ impl> LocalSttPluginExt for T { Some(ServerType::External) => { if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { let actor: ActorRef = cell.into(); - actor.stop(None); - stopped = true; + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } } Some(ServerType::Internal) => { if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { let actor: ActorRef = cell.into(); - actor.stop(None); - stopped = true; + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } } Some(ServerType::Custom) => {} None => { if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { let actor: ActorRef = cell.into(); - actor.stop(None); - stopped = true; + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { let actor: ActorRef = cell.into(); - actor.stop(None); - stopped = true; + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } } } @@ -425,25 +403,15 @@ impl> LocalSttPluginExt for T { #[tracing::instrument(skip_all)] async fn get_servers(&self) -> Result, crate::Error> { - let internal_health = - if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { - let actor: ActorRef = cell.into(); - let (_, status) = - call_t!(actor, internal::InternalSTTMessage::GetHealth, 10 * 1000).unwrap(); - status - } else { - ServerHealth::Unreachable - }; + let internal_health = internal_health() + .await + .map(|r| r.1) + .unwrap_or(ServerHealth::Unreachable); - let external_health = - if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { - let actor: ActorRef = cell.into(); - let (_, status) = - call_t!(actor, external::ExternalSTTMessage::GetHealth, 10 * 1000).unwrap(); - status - } else { - ServerHealth::Unreachable - }; + let external_health = external_health() + .await + .map(|r| r.1) + .unwrap_or(ServerHealth::Unreachable); let custom_health = { let provider = self.get_provider()?; @@ -657,3 +625,29 @@ impl> LocalSttPluginExt for T { Ok(()) } } + +async fn internal_health() -> Option<(String, ServerHealth)> { + match registry::where_is(internal::InternalSTTActor::name()) { + Some(cell) => { + let actor: ActorRef = cell.into(); + match call_t!(actor, internal::InternalSTTMessage::GetHealth, 10 * 1000) { + Ok(r) => Some(r), + Err(_) => None, + } + } + None => None, + } +} + +async fn external_health() -> Option<(String, ServerHealth)> { + match registry::where_is(external::ExternalSTTActor::name()) { + Some(cell) => { + let actor: ActorRef = cell.into(); + match call_t!(actor, external::ExternalSTTMessage::GetHealth, 10 * 1000) { + Ok(r) => Some(r), + Err(_) => None, + } + } + None => None, + } +} diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index e09d1639d..11deece4c 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -15,7 +15,7 @@ pub use error::*; use events::*; pub use ext::*; pub use model::*; -pub use server::GROUP; +pub use server::*; pub use store::*; pub use types::*; diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index 2663c8cf4..39919d980 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -3,7 +3,7 @@ use tauri_plugin_shell::process::{Command, CommandChild}; use super::ServerHealth; use backon::{ConstantBuilder, Retryable}; -use ractor::{pg, Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; pub enum ExternalSTTMessage { GetHealth(RpcReplyPort<(String, ServerHealth)>), @@ -45,8 +45,6 @@ impl Actor for ExternalSTTActor { myself: ActorRef, args: Self::Arguments, ) -> Result { - pg::join(super::GROUP.into(), vec![myself.get_cell()]); - let port = port_check::free_local_port().unwrap(); let (mut rx, child) = args.cmd.args(["--port", &port.to_string()]).spawn()?; let base_url = format!("http://localhost:{}", port); @@ -55,25 +53,15 @@ impl Actor for ExternalSTTActor { let task_handle = tokio::spawn(async move { loop { match rx.recv().await { - Some(tauri_plugin_shell::process::CommandEvent::Stdout(bytes)) => { - if let Ok(text) = String::from_utf8(bytes) { - let text = text.trim(); - if !text.is_empty() - && !text.contains("[TranscriptionHandler]") - && !text.contains("[WebSocket]") - && !text.contains("Sent interim") - { - tracing::info!("{}", text); - } - } - } - Some(tauri_plugin_shell::process::CommandEvent::Stderr(bytes)) => { + Some(tauri_plugin_shell::process::CommandEvent::Stdout(bytes)) + | Some(tauri_plugin_shell::process::CommandEvent::Stderr(bytes)) => { if let Ok(text) = String::from_utf8(bytes) { let text = text.trim(); if !text.is_empty() && !text.contains("[TranscriptionHandler]") && !text.contains("[WebSocket]") && !text.contains("Sent interim") + && !text.contains("/v1/status") { tracing::info!("{}", text); } @@ -145,11 +133,9 @@ impl Actor for ExternalSTTActor { async fn post_stop( &self, - myself: ActorRef, + _myself: ActorRef, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { - pg::leave(super::GROUP.into(), vec![myself.get_cell()]); - if let Some(process) = state.process_handle.take() { if let Err(e) = process.kill() { tracing::error!("failed_to_kill_process: {:?}", e); diff --git a/plugins/local-stt/src/server/internal.rs b/plugins/local-stt/src/server/internal.rs index b9bee2dea..8329523b5 100644 --- a/plugins/local-stt/src/server/internal.rs +++ b/plugins/local-stt/src/server/internal.rs @@ -4,7 +4,7 @@ use std::{ }; use axum::{error_handling::HandleError, Router}; -use ractor::{pg, Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; use reqwest::StatusCode; use tower_http::cors::{self, CorsLayer}; @@ -45,8 +45,6 @@ impl Actor for InternalSTTActor { myself: ActorRef, args: Self::Arguments, ) -> Result { - pg::join(super::GROUP.into(), vec![myself.get_cell()]); - let model_path = args.model_cache_dir.join(args.model_type.file_name()); let whisper_service = HandleError::new( @@ -94,10 +92,9 @@ impl Actor for InternalSTTActor { async fn post_stop( &self, - myself: ActorRef, + _myself: ActorRef, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { - pg::leave(super::GROUP.into(), vec![myself.get_cell()]); let _ = state.shutdown.send(()); state.server_task.abort(); Ok(()) diff --git a/plugins/local-stt/src/server/mod.rs b/plugins/local-stt/src/server/mod.rs index 8e3bd1faa..2383cbb37 100644 --- a/plugins/local-stt/src/server/mod.rs +++ b/plugins/local-stt/src/server/mod.rs @@ -1,8 +1,6 @@ pub mod external; pub mod internal; -pub const GROUP: &str = "stt"; - #[derive( Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize, specta::Type, )] From d289294c19fed1dae384ac7e0ef4f5dcda358841 Mon Sep 17 00:00:00 2001 From: Yujong Lee Date: Thu, 25 Sep 2025 15:47:20 +0900 Subject: [PATCH 4/4] wip --- plugins/local-stt/src/server/external.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index 39919d980..a473846b7 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -153,12 +153,15 @@ impl Actor for ExternalSTTActor { async fn handle( &self, - _myself: ActorRef, + myself: ActorRef, message: Self::Msg, state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { match message { - ExternalSTTMessage::ProcessTerminated(e) => Err(e.into()), + ExternalSTTMessage::ProcessTerminated(e) => { + myself.stop(Some(e)); + Ok(()) + } ExternalSTTMessage::GetHealth(reply_port) => { let status = match state.client.status().await { Ok(r) => match r.model_state {