diff --git a/Cargo.lock b/Cargo.lock index 6b07ab9b1..719fdeac6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14548,6 +14548,7 @@ dependencies = [ "audio-utils", "axum 0.8.4", "axum-extra", + "backon", "data", "dirs 6.0.0", "download-interface", @@ -14561,6 +14562,7 @@ dependencies = [ "owhisper-interface", "port-killer", "port_check", + "ractor", "reqwest 0.12.23", "rodio", "serde", diff --git a/Cargo.toml b/Cargo.toml index c3aab0ff5..590ac45a2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -129,6 +129,7 @@ async-stream = "0.3.6" futures-channel = "0.3.31" futures-core = "0.3.31" futures-util = "0.3.31" +ractor = "0.15" reqwest = "0.12" reqwest-streams = "0.10.0" tokio = "1" @@ -138,7 +139,7 @@ tokio-util = "0.7.15" anyhow = "1" approx = "0.5.1" -backon = "1.4.1" +backon = "1.5.2" bytes = "1.9.0" cached = "0.55.1" clap = "4" diff --git a/crates/transcribe-whisper-local/src/lib.rs b/crates/transcribe-whisper-local/src/lib.rs index a8c2ef352..05c2bee37 100644 --- a/crates/transcribe-whisper-local/src/lib.rs +++ b/crates/transcribe-whisper-local/src/lib.rs @@ -8,6 +8,7 @@ pub use service::*; // cargo test -p transcribe-whisper-local test_service -- --nocapture mod tests { use super::*; + use axum::{error_handling::HandleError, http::StatusCode}; use futures_util::StreamExt; use hypr_audio_utils::AudioFormatExt; @@ -18,7 +19,10 @@ mod tests { .join("com.hyprnote.dev") .join("stt/ggml-small-q8_0.bin"); - let service = TranscribeService::builder().model_path(model_path).build(); + let service = HandleError::new( + TranscribeService::builder().model_path(model_path).build(), + move |err: String| async move { (StatusCode::INTERNAL_SERVER_ERROR, err) }, + ); let app = axum::Router::new().route_service("/v1/listen", service); @@ -43,8 +47,7 @@ mod tests { .to_i16_le_chunks(16000, 512); let input = audio.map(|chunk| owhisper_interface::MixedMessage::Audio(chunk)); - let stream = client.from_realtime_audio(input).await.unwrap(); - futures_util::pin_mut!(stream); + let _ = client.from_realtime_audio(input).await.unwrap(); server_handle.abort(); Ok(()) diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs index cb8b061f9..796701c17 100644 --- a/crates/transcribe-whisper-local/src/service/streaming.rs +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -62,7 +62,7 @@ where B: Send + 'static, { type Response = Response; - type Error = std::convert::Infallible; + type Error = String; type Future = Pin> + Send>>; fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { diff --git a/plugins/listener/Cargo.toml b/plugins/listener/Cargo.toml index 6c0c1fb12..d3f70a321 100644 --- a/plugins/listener/Cargo.toml +++ b/plugins/listener/Cargo.toml @@ -61,7 +61,7 @@ hound = { workspace = true } vorbis_rs = { workspace = true } futures-util = { workspace = true } -ractor = "0.15" +ractor = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } tokio-stream = { workspace = true } tokio-util = { workspace = true } diff --git a/plugins/listener/src/actors/listener.rs b/plugins/listener/src/actors/listener.rs index 52b76795e..bfb7b7975 100644 --- a/plugins/listener/src/actors/listener.rs +++ b/plugins/listener/src/actors/listener.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use futures_util::StreamExt; use owhisper_interface::{ControlMessage, MixedMessage, Word2}; -use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, SupervisionEvent}; 
use tauri_specta::Event; use crate::{manager::TranscriptManager, SessionEvent}; @@ -27,6 +27,7 @@ pub struct ListenerArgs { pub struct ListenerState { tx: tokio::sync::mpsc::Sender>, rx_task: tokio::task::JoinHandle<()>, + shutdown_tx: Option>, } pub struct ListenerActor; @@ -47,8 +48,31 @@ impl Actor for ListenerActor { myself: ActorRef, args: Self::Arguments, ) -> Result { - let (tx, rx_task) = spawn_rx_task(args, myself).await.unwrap(); - Ok(ListenerState { tx, rx_task }) + { + use tauri_plugin_local_stt::LocalSttPluginExt; + let _ = args.app.start_server(None).await; + } + + let (tx, rx_task, shutdown_tx) = spawn_rx_task(args, myself).await.unwrap(); + let state = ListenerState { + tx, + rx_task, + shutdown_tx: Some(shutdown_tx), + }; + + Ok(state) + } + + async fn post_stop( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + if let Some(shutdown_tx) = state.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + state.rx_task.abort(); + Ok(()) } async fn handle( @@ -65,12 +89,21 @@ impl Actor for ListenerActor { Ok(()) } - async fn post_stop( + async fn handle_supervisor_evt( &self, - _myself: ActorRef, - state: &mut Self::State, + myself: ActorRef, + message: SupervisionEvent, + _state: &mut Self::State, ) -> Result<(), ActorProcessingErr> { - state.rx_task.abort(); + tracing::info!("supervisor_event: {:?}", message); + + match message { + SupervisionEvent::ActorStarted(_) | SupervisionEvent::ProcessGroupChanged(_) => {} + SupervisionEvent::ActorTerminated(_, _, _) => {} + SupervisionEvent::ActorFailed(_cell, _) => { + myself.stop(None); + } + } Ok(()) } } @@ -82,10 +115,12 @@ async fn spawn_rx_task( ( tokio::sync::mpsc::Sender>, tokio::task::JoinHandle<()>, + tokio::sync::oneshot::Sender<()>, ), ActorProcessingErr, > { let (tx, rx) = tokio::sync::mpsc::channel::>(32); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); let app = args.app.clone(); let session_id = args.session_id.clone(); @@ -109,7 +144,7 @@ async fn spawn_rx_task( let rx_task = tokio::spawn(async move { let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); - let (listen_stream, _handle) = match client.from_realtime_audio(outbound).await { + let (listen_stream, handle) = match client.from_realtime_audio(outbound).await { Ok(res) => res, Err(e) => { tracing::error!("listen_ws_connect_failed: {:?}", e); @@ -122,8 +157,14 @@ async fn spawn_rx_task( let mut manager = TranscriptManager::with_unix_timestamp(session_start_ts_ms); loop { - match tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()).await { - Ok(Some(response)) => { + tokio::select! 
{ + _ = &mut shutdown_rx => { + handle.finalize_with_text(serde_json::json!({"type": "Finalize"}).to_string().into()).await; + break; + } + result = tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()) => { + match result { + Ok(Some(response)) => { let diff = manager.append(response.clone()); let partial_words_by_channel: HashMap> = diff @@ -179,13 +220,15 @@ async fn spawn_rx_task( .emit(&app) .unwrap(); } - Ok(None) => { - tracing::info!("listen_stream_ended"); - break; - } - Err(_) => { - tracing::info!("listen_stream_timeout"); - break; + Ok(None) => { + tracing::info!("listen_stream_ended"); + break; + } + Err(_) => { + tracing::info!("listen_stream_timeout"); + break; + } + } } } } @@ -193,7 +236,7 @@ async fn spawn_rx_task( myself.stop(None); }); - Ok((tx, rx_task)) + Ok((tx, rx_task, shutdown_tx)) } async fn update_session( diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index 111ca2e85..a1e1cc093 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -74,7 +74,9 @@ axum = { workspace = true, features = ["ws"] } axum-extra = { workspace = true, features = ["query"] } tower-http = { workspace = true, features = ["cors", "trace"] } +backon = { workspace = true } futures-util = { workspace = true } +ractor = { workspace = true } reqwest = { workspace = true } tokio = { workspace = true, features = ["rt", "macros"] } tokio-util = { workspace = true } diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 4c8d6b654..5bd885b77 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -1,5 +1,8 @@ use std::{collections::HashMap, future::Future, path::PathBuf}; +use ractor::{call_t, registry, Actor, ActorRef}; +use tokio_util::sync::CancellationToken; + use tauri::{ipc::Channel, Manager, Runtime}; use tauri_plugin_shell::ShellExt; use tauri_plugin_store2::StorePluginExt; @@ -7,7 +10,6 @@ use tauri_plugin_store2::StorePluginExt; use hypr_download_interface::DownloadProgress; use hypr_file::download_file_parallel_cancellable; use hypr_whisper_local_model::WhisperModel; -use tokio_util::sync::CancellationToken; use crate::{ model::SupportedSttModel, @@ -147,11 +149,7 @@ impl> LocalSttPluginExt for T { }) } SupportedSttModel::Am(_) => { - let existing_api_base = { - let state = self.state::(); - let guard = state.lock().await; - guard.external_server.as_ref().map(|s| s.base_url.clone()) - }; + let existing_api_base = external_health().await.map(|r| r.0); let am_key = { let state = self.state::(); @@ -177,11 +175,7 @@ impl> LocalSttPluginExt for T { Ok(conn) } SupportedSttModel::Whisper(_) => { - let existing_api_base = { - let state = self.state::(); - let guard = state.lock().await; - guard.internal_server.as_ref().map(|s| s.base_url.clone()) - }; + let existing_api_base = internal_health().await.map(|r| r.0); let conn = match existing_api_base { Some(api_base) => Connection { @@ -259,13 +253,7 @@ impl> LocalSttPluginExt for T { return Err(crate::Error::ModelNotDownloaded); } - if self - .state::() - .lock() - .await - .internal_server - .is_some() - { + if registry::where_is(internal::InternalSTTActor::name()).is_some() { return Err(crate::Error::ServerAlreadyRunning); } @@ -276,31 +264,22 @@ impl> LocalSttPluginExt for T { } }; - let server_state = internal::ServerState::builder() - .model_cache_dir(cache_dir) - .model_type(whisper_model) - .build(); - - let server = internal::run_server(server_state).await?; - let base_url = server.base_url.clone(); - 
tokio::time::sleep(std::time::Duration::from_millis(100)).await; - - { - let state = self.state::(); - let mut s = state.lock().await; - s.internal_server = Some(server); - } - + let (_server, _) = Actor::spawn( + Some(internal::InternalSTTActor::name()), + internal::InternalSTTActor, + internal::InternalSTTArgs { + model_cache_dir: cache_dir, + model_type: whisper_model, + }, + ) + .await + .unwrap(); + + let base_url = internal_health().await.map(|r| r.0).unwrap(); Ok(base_url) } ServerType::External => { - if self - .state::() - .lock() - .await - .external_server - .is_some() - { + if registry::where_is(external::ExternalSTTActor::name()).is_some() { return Err(crate::Error::ServerAlreadyRunning); } @@ -349,18 +328,21 @@ impl> LocalSttPluginExt for T { .args(["serve"]) }; - let server = external::run_server(cmd, am_key).await?; - tokio::time::sleep(std::time::Duration::from_millis(250)).await; - let _ = server.init(am_model, data_dir).await; - let api_base = server.base_url.clone(); - - { - let state = self.state::(); - let mut s = state.lock().await; - s.external_server = Some(server); - } - - Ok(api_base) + let (_server, _) = Actor::spawn( + Some(external::ExternalSTTActor::name()), + external::ExternalSTTActor, + external::ExternalSTTArgs { + cmd, + api_key: am_key, + model: am_model, + models_dir: data_dir, + }, + ) + .await + .unwrap(); + + let base_url = external_health().await.map(|v| v.0).unwrap(); + Ok(base_url) } } } @@ -373,30 +355,45 @@ impl> LocalSttPluginExt for T { return Ok(false); } - let state = self.state::(); - let mut s = state.lock().await; - let mut stopped = false; match server_type { Some(ServerType::External) => { - hypr_host::kill_processes_by_matcher(hypr_host::ProcessMatcher::Sidecar); - - if let Some(_) = s.external_server.take() { - stopped = true; + if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } } Some(ServerType::Internal) => { - if let Some(_) = s.internal_server.take() { - stopped = true; + if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } } Some(ServerType::Custom) => {} None => { - if let Some(_) = s.external_server.take() { - stopped = true; + if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } - if let Some(_) = s.internal_server.take() { - stopped = true; + if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) { + let actor: ActorRef = cell.into(); + if let Err(e) = actor.stop_and_wait(None, None).await { + tracing::error!("stop_server: {:?}", e); + } else { + stopped = true; + } } } } @@ -406,21 +403,15 @@ impl> LocalSttPluginExt for T { #[tracing::instrument(skip_all)] async fn get_servers(&self) -> Result, crate::Error> { - let state = self.state::(); - let guard = state.lock().await; + let internal_health = internal_health() + .await + .map(|r| r.1) + .unwrap_or(ServerHealth::Unreachable); - let internal_health = if let Some(server) = &guard.internal_server { - let status = server.health().await; - status - } else { - 
ServerHealth::Unreachable - }; - - let external_health = if let Some(server) = &guard.external_server { - server.health().await - } else { - ServerHealth::Unreachable - }; + let external_health = external_health() + .await + .map(|r| r.1) + .unwrap_or(ServerHealth::Unreachable); let custom_health = { let provider = self.get_provider()?; @@ -634,3 +625,29 @@ impl> LocalSttPluginExt for T { Ok(()) } } + +async fn internal_health() -> Option<(String, ServerHealth)> { + match registry::where_is(internal::InternalSTTActor::name()) { + Some(cell) => { + let actor: ActorRef = cell.into(); + match call_t!(actor, internal::InternalSTTMessage::GetHealth, 10 * 1000) { + Ok(r) => Some(r), + Err(_) => None, + } + } + None => None, + } +} + +async fn external_health() -> Option<(String, ServerHealth)> { + match registry::where_is(external::ExternalSTTActor::name()) { + Some(cell) => { + let actor: ActorRef = cell.into(); + match call_t!(actor, external::ExternalSTTMessage::GetHealth, 10 * 1000) { + Ok(r) => Some(r), + Err(_) => None, + } + } + None => None, + } +} diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index 8883af529..11deece4c 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -15,6 +15,7 @@ pub use error::*; use events::*; pub use ext::*; pub use model::*; +pub use server::*; pub use store::*; pub use types::*; @@ -23,8 +24,6 @@ pub type SharedState = std::sync::Arc>; #[derive(Default)] pub struct State { pub am_api_key: Option, - pub internal_server: Option, - pub external_server: Option, pub download_task: HashMap, CancellationToken)>, } diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs index 1b0fa9e52..a473846b7 100644 --- a/plugins/local-stt/src/server/external.rs +++ b/plugins/local-stt/src/server/external.rs @@ -1,125 +1,186 @@ +use std::path::PathBuf; +use tauri_plugin_shell::process::{Command, CommandChild}; + use super::ServerHealth; +use backon::{ConstantBuilder, Retryable}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; + +pub enum ExternalSTTMessage { + GetHealth(RpcReplyPort<(String, ServerHealth)>), + ProcessTerminated(String), +} -pub struct ServerHandle { - pub base_url: String, +pub struct ExternalSTTArgs { + pub cmd: Command, + pub api_key: String, + pub model: hypr_am::AmModel, + pub models_dir: PathBuf, +} + +pub struct ExternalSTTState { + base_url: String, api_key: Option, - shutdown: tokio::sync::watch::Sender<()>, + model: hypr_am::AmModel, + models_dir: PathBuf, client: hypr_am::Client, + process_handle: Option, + task_handle: Option>, } -impl Drop for ServerHandle { - fn drop(&mut self) { - tracing::info!("stopping"); - let _ = self.shutdown.send(()); +pub struct ExternalSTTActor; + +impl ExternalSTTActor { + pub fn name() -> ActorName { + "external_stt".into() } } -impl ServerHandle { - pub async fn health(&self) -> ServerHealth { - let res = self.client.status().await; - if res.is_err() { - tracing::error!("{:?}", res); - return ServerHealth::Unreachable; - } +impl Actor for ExternalSTTActor { + type Msg = ExternalSTTMessage; + type State = ExternalSTTState; + type Arguments = ExternalSTTArgs; - let res = res.unwrap(); + async fn pre_start( + &self, + myself: ActorRef, + args: Self::Arguments, + ) -> Result { + let port = port_check::free_local_port().unwrap(); + let (mut rx, child) = args.cmd.args(["--port", &port.to_string()]).spawn()?; + let base_url = format!("http://localhost:{}", port); + let client = 
hypr_am::Client::new(&base_url); - if res.model_state == hypr_am::ModelState::Loading { - return ServerHealth::Loading; - } + let task_handle = tokio::spawn(async move { + loop { + match rx.recv().await { + Some(tauri_plugin_shell::process::CommandEvent::Stdout(bytes)) + | Some(tauri_plugin_shell::process::CommandEvent::Stderr(bytes)) => { + if let Ok(text) = String::from_utf8(bytes) { + let text = text.trim(); + if !text.is_empty() + && !text.contains("[TranscriptionHandler]") + && !text.contains("[WebSocket]") + && !text.contains("Sent interim") + && !text.contains("/v1/status") + { + tracing::info!("{}", text); + } + } + } + Some(tauri_plugin_shell::process::CommandEvent::Terminated(payload)) => { + let e = format!("{:?}", payload); + tracing::error!("{}", e); + let _ = myself.send_message(ExternalSTTMessage::ProcessTerminated(e)); + break; + } + Some(tauri_plugin_shell::process::CommandEvent::Error(error)) => { + tracing::error!("{}", error); + let _ = myself.send_message(ExternalSTTMessage::ProcessTerminated(error)); + break; + } + None => { + tracing::warn!("closed"); + break; + } + _ => {} + } + } + }); - if res.model_state == hypr_am::ModelState::Loaded { - return ServerHealth::Ready; - } + Ok(ExternalSTTState { + base_url, + api_key: Some(args.api_key), + model: args.model, + models_dir: args.models_dir, + client, + process_handle: Some(child), + task_handle: Some(task_handle), + }) + } + async fn post_start( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + let api_key = state.api_key.clone().unwrap(); + let model = state.model.clone(); + let models_dir = state.models_dir.clone(); + + let res = (|| async { + state + .client + .init( + hypr_am::InitRequest::new(api_key.clone()) + .with_model(model.clone(), &models_dir), + ) + .await + }) + .retry( + ConstantBuilder::default() + .with_max_times(20) + .with_delay(std::time::Duration::from_millis(500)), + ) + .when(|e| { + tracing::error!("external_stt_init_failed: {:?}", e); + true + }) + .sleep(tokio::time::sleep) + .await?; - ServerHealth::Unreachable + tracing::info!(res = ?res); + Ok(()) } - pub async fn init( + async fn post_stop( &self, - model: hypr_am::AmModel, - models_dir: impl AsRef, - ) -> Result { - let r = self - .client - .init( - hypr_am::InitRequest::new(self.api_key.clone().unwrap()) - .with_model(model, models_dir), - ) - .await?; - - Ok(r) + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + if let Some(process) = state.process_handle.take() { + if let Err(e) = process.kill() { + tracing::error!("failed_to_kill_process: {:?}", e); + } + } + + if let Some(task) = state.task_handle.take() { + task.abort(); + } + + hypr_host::kill_processes_by_matcher(hypr_host::ProcessMatcher::Sidecar); + + Ok(()) } -} -pub async fn run_server( - cmd: tauri_plugin_shell::process::Command, - am_key: String, -) -> Result { - let port = port_check::free_local_port().unwrap(); - let (mut rx, child) = cmd.args(["--port", &port.to_string()]).spawn()?; - let base_url = format!("http://localhost:{}", port); - let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); - let client = hypr_am::Client::new(&base_url); - - tokio::spawn(async move { - let mut process_ended = false; - - loop { - tokio::select! 
{ - _ = shutdown_rx.changed() => { - tracing::info!("shutdown_signal_received"); - break; - } - event = rx.recv() => { - match event { - Some(tauri_plugin_shell::process::CommandEvent::Stdout(bytes)) => { - if let Ok(text) = String::from_utf8(bytes) { - let text = text.trim(); - if !text.is_empty() && !text.contains("[TranscriptionHandler]") && !text.contains("[WebSocket]") && !text.contains("Sent interim") { - tracing::info!("{}", text); - } - } - } - Some(tauri_plugin_shell::process::CommandEvent::Stderr(bytes)) => { - if let Ok(text) = String::from_utf8(bytes) { - let text = text.trim(); - if !text.is_empty() && !text.contains("[TranscriptionHandler]") && !text.contains("[WebSocket]") && !text.contains("Sent interim") { - tracing::info!("{}", text); - } - } - } - Some(tauri_plugin_shell::process::CommandEvent::Terminated(payload)) => { - tracing::error!("terminated: {:?}", payload); - process_ended = true; - break; - } - Some(tauri_plugin_shell::process::CommandEvent::Error(error)) => { - tracing::error!("{}", error); - break; - } - None => { - tracing::warn!("closed"); - process_ended = true; - break; - } - _ => {} + async fn handle( + &self, + myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + ExternalSTTMessage::ProcessTerminated(e) => { + myself.stop(Some(e)); + Ok(()) + } + ExternalSTTMessage::GetHealth(reply_port) => { + let status = match state.client.status().await { + Ok(r) => match r.model_state { + hypr_am::ModelState::Loading => ServerHealth::Loading, + hypr_am::ModelState::Loaded => ServerHealth::Ready, + _ => ServerHealth::Unreachable, + }, + Err(e) => { + tracing::error!("{:?}", e); + ServerHealth::Unreachable } + }; + + if let Err(e) = reply_port.send((state.base_url.clone(), status)) { + return Err(e.into()); } - } - } - if !process_ended { - if let Err(e) = child.kill() { - tracing::error!("{:?}", e); + Ok(()) } } - }); - - Ok(ServerHandle { - api_key: Some(am_key), - base_url, - shutdown: shutdown_tx, - client, - }) + } } diff --git a/plugins/local-stt/src/server/internal.rs b/plugins/local-stt/src/server/internal.rs index d85823a89..8329523b5 100644 --- a/plugins/local-stt/src/server/internal.rs +++ b/plugins/local-stt/src/server/internal.rs @@ -3,166 +3,120 @@ use std::{ path::PathBuf, }; -use axum::{http::StatusCode, response::IntoResponse, routing::get, Router}; +use axum::{error_handling::HandleError, Router}; +use ractor::{Actor, ActorName, ActorProcessingErr, ActorRef, RpcReplyPort}; +use reqwest::StatusCode; use tower_http::cors::{self, CorsLayer}; use super::ServerHealth; use hypr_whisper_local_model::WhisperModel; -#[derive(Default)] -pub struct ServerStateBuilder { - pub model_type: Option, - pub model_cache_dir: Option, +pub enum InternalSTTMessage { + GetHealth(RpcReplyPort<(String, ServerHealth)>), + ServerError(String), } -impl ServerStateBuilder { - pub fn model_cache_dir(mut self, model_cache_dir: PathBuf) -> Self { - self.model_cache_dir = Some(model_cache_dir); - self - } - - pub fn model_type(mut self, model_type: WhisperModel) -> Self { - self.model_type = Some(model_type); - self - } - - pub fn build(self) -> ServerState { - ServerState { - model_type: self.model_type.unwrap(), - model_cache_dir: self.model_cache_dir.unwrap(), - } - } +pub struct InternalSTTArgs { + pub model_type: WhisperModel, + pub model_cache_dir: PathBuf, } -#[derive(Clone)] -pub struct ServerState { - model_type: WhisperModel, - model_cache_dir: PathBuf, +pub struct InternalSTTState { + base_url: 
String, + shutdown: tokio::sync::watch::Sender<()>, + server_task: tokio::task::JoinHandle<()>, } -impl ServerState { - pub fn builder() -> ServerStateBuilder { - ServerStateBuilder::default() +pub struct InternalSTTActor; + +impl InternalSTTActor { + pub fn name() -> ActorName { + "internal_stt".into() } } -#[derive(Clone)] -pub struct ServerHandle { - pub base_url: String, - pub api_key: Option, - shutdown: tokio::sync::watch::Sender<()>, -} +impl Actor for InternalSTTActor { + type Msg = InternalSTTMessage; + type State = InternalSTTState; + type Arguments = InternalSTTArgs; + + async fn pre_start( + &self, + myself: ActorRef, + args: Self::Arguments, + ) -> Result { + let model_path = args.model_cache_dir.join(args.model_type.file_name()); + + let whisper_service = HandleError::new( + hypr_transcribe_whisper_local::TranscribeService::builder() + .model_path(model_path) + .build(), + move |err: String| async move { + let _ = myself.send_message(InternalSTTMessage::ServerError(err.clone())); + (StatusCode::INTERNAL_SERVER_ERROR, err) + }, + ); + + let router = Router::new() + .route_service("/v1/listen", whisper_service) + .layer( + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods(cors::Any) + .allow_headers(cors::Any), + ); + + let listener = + tokio::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await?; + + let server_addr = listener.local_addr()?; + let base_url = format!("http://{}", server_addr); + + let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); + + let server_task = tokio::spawn(async move { + axum::serve(listener, router) + .with_graceful_shutdown(async move { + shutdown_rx.changed().await.ok(); + }) + .await + .unwrap(); + }); + + Ok(InternalSTTState { + base_url, + shutdown: shutdown_tx, + server_task, + }) + } -impl Drop for ServerHandle { - fn drop(&mut self) { - tracing::info!("stopping: {}", self.base_url); - let _ = self.shutdown.send(()); + async fn post_stop( + &self, + _myself: ActorRef, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + let _ = state.shutdown.send(()); + state.server_task.abort(); + Ok(()) } -} -impl ServerHandle { - pub async fn health(&self) -> ServerHealth { - let client = reqwest::Client::new(); - let url = format!("{}/health", self.base_url); - - match client - .get(&url) - .timeout(std::time::Duration::from_secs(2)) - .send() - .await - { - Ok(res) if res.status().is_success() => ServerHealth::Ready, - Ok(_res) => ServerHealth::Unreachable, - Err(e) => { - tracing::error!("{:?}", e); - ServerHealth::Unreachable + async fn handle( + &self, + _myself: ActorRef, + message: Self::Msg, + state: &mut Self::State, + ) -> Result<(), ActorProcessingErr> { + match message { + InternalSTTMessage::GetHealth(reply_port) => { + let status = ServerHealth::Ready; + + if let Err(e) = reply_port.send((state.base_url.clone(), status)) { + return Err(e.into()); + } + + Ok(()) } + InternalSTTMessage::ServerError(e) => Err(e.into()), } } } - -pub async fn run_server(state: ServerState) -> Result { - tracing::info!("starting"); - let router = make_service_router(state); - - let listener = - tokio::net::TcpListener::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0))).await?; - - let server_addr = listener.local_addr()?; - let base_url = format!("http://{}", server_addr); - - let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); - - let server_handle = ServerHandle { - base_url, - api_key: None, - shutdown: shutdown_tx, - }; - - tokio::spawn(async move { - axum::serve(listener, router) 
- .with_graceful_shutdown(async move { - shutdown_rx.changed().await.ok(); - }) - .await - .unwrap(); - }); - - tracing::info!("local_stt_server_started {}", server_addr); - Ok(server_handle) -} - -fn make_service_router(state: ServerState) -> Router { - let model_path = state.model_cache_dir.join(state.model_type.file_name()); - - let whisper_service = hypr_transcribe_whisper_local::TranscribeService::builder() - .model_path(model_path) - .build(); - - Router::new() - .route("/health", get(health)) - .route_service("/v1/listen", whisper_service) - .layer( - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods(cors::Any) - .allow_headers(cors::Any), - ) -} - -async fn health() -> impl IntoResponse { - StatusCode::OK -} - -#[cfg(test)] -mod tests { - use super::*; - - use axum::body::Body; - use axum::http::{Request, StatusCode}; - use tower::ServiceExt; - - use hypr_whisper_local_model::WhisperModel; - - #[tokio::test] - async fn test_health_endpoint() { - let state = ServerStateBuilder::default() - .model_cache_dir(dirs::data_dir().unwrap().join("com.hyprnote.dev/stt")) - .model_type(WhisperModel::QuantizedTinyEn) - .build(); - - let app = make_service_router(state); - - let response = app - .oneshot( - Request::builder() - .uri("/health") - .body(Body::empty()) - .unwrap(), - ) - .await - .unwrap(); - - assert_eq!(response.status(), StatusCode::OK); - } -}
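
The `streaming.rs` change (the service's `Error` goes from `Infallible` to `String`) is what forces the `HandleError` wrapping seen in the transcribe-whisper-local test and in `internal.rs`: axum's `Router::route_service` only accepts services whose error type is `Infallible`, so the fallible service has to be adapted first. A minimal sketch of that adapter, using a hypothetical `FallibleService` in place of the real `TranscribeService`:

```rust
// Minimal sketch (not the PR's real service): a tower Service with
// `Error = String` cannot be mounted on a Router directly, so it is
// wrapped in `HandleError`, which maps the error to an HTTP response.
use std::{future::Future, pin::Pin, task::{Context, Poll}};

use axum::{
    body::Body,
    error_handling::HandleError,
    http::{Request, StatusCode},
    response::Response,
    Router,
};

#[derive(Clone)]
struct FallibleService; // hypothetical stand-in for TranscribeService

impl tower::Service<Request<Body>> for FallibleService {
    type Response = Response;
    type Error = String; // same change as streaming.rs: no longer Infallible
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, _req: Request<Body>) -> Self::Future {
        Box::pin(async { Err("model failed to load".to_string()) })
    }
}

fn router() -> Router {
    // The closure decides what an error looks like on the wire; the wrapped
    // service is infallible from axum's point of view.
    let svc = HandleError::new(FallibleService, |err: String| async move {
        (StatusCode::INTERNAL_SERVER_ERROR, err)
    });
    Router::new().route_service("/v1/listen", svc)
}
```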
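
The `ext.rs` rewrite drops the `internal_server` / `external_server` fields from shared plugin state and instead treats each server as a named `ractor` actor: spawn under a well-known name, discover it through `registry::where_is`, query health over `call_t!`, and stop it with `stop_and_wait`. A reduced sketch of that pattern, with hypothetical `SttActor` / `SttMessage` types standing in for the plugin's actors:

```rust
// Reduced sketch of the named-actor pattern (hypothetical SttActor/SttMessage,
// not the plugin's real types): spawn by name, look the actor up in the global
// registry, health-check it with a timed RPC, then stop it.
use ractor::{call_t, registry, Actor, ActorProcessingErr, ActorRef, RpcReplyPort};

enum SttMessage {
    // Reply carries (base_url, healthy) for illustration.
    GetHealth(RpcReplyPort<(String, bool)>),
}

struct SttActor;

struct SttState {
    base_url: String,
}

impl Actor for SttActor {
    type Msg = SttMessage;
    type State = SttState;
    type Arguments = String; // base_url

    async fn pre_start(
        &self,
        _myself: ActorRef<Self::Msg>,
        base_url: Self::Arguments,
    ) -> Result<Self::State, ActorProcessingErr> {
        Ok(SttState { base_url })
    }

    async fn handle(
        &self,
        _myself: ActorRef<Self::Msg>,
        message: Self::Msg,
        state: &mut Self::State,
    ) -> Result<(), ActorProcessingErr> {
        match message {
            SttMessage::GetHealth(reply) => {
                let _ = reply.send((state.base_url.clone(), true));
                Ok(())
            }
        }
    }
}

// Mirrors internal_health()/external_health(): no handle held in shared state,
// just a registry lookup plus a 10s call_t!.
async fn health() -> Option<(String, bool)> {
    let cell = registry::where_is("stt".into())?;
    let actor: ActorRef<SttMessage> = cell.into();
    call_t!(actor, SttMessage::GetHealth, 10 * 1000).ok()
}

#[tokio::main]
async fn main() {
    let (_actor, _join) =
        Actor::spawn(Some("stt".into()), SttActor, "http://localhost:8080".into())
            .await
            .expect("spawn");

    println!("health: {:?}", health().await);

    // Mirrors stop_server(): resolve by name, then stop and wait for shutdown.
    if let Some(cell) = registry::where_is("stt".into()) {
        let actor: ActorRef<SttMessage> = cell.into();
        let _ = actor.stop_and_wait(None, None).await;
    }
}
```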
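
`ExternalSTTActor::post_start` replaces the old fixed 250 ms sleep with a `backon` retry around the sidecar `init` call (20 attempts, 500 ms apart), so startup waits for the process to actually answer instead of guessing. The same retry shape applied to a hypothetical `ping` helper:

```rust
// Sketch of the retry shape used in ExternalSTTActor::post_start, applied to a
// hypothetical `ping` call. backon drives the closure until it succeeds or the
// attempt budget runs out.
use backon::{ConstantBuilder, Retryable};

async fn ping(base_url: &str) -> Result<(), reqwest::Error> {
    reqwest::get(base_url).await?.error_for_status()?;
    Ok(())
}

async fn wait_until_up(base_url: &str) -> Result<(), reqwest::Error> {
    (|| async { ping(base_url).await })
        .retry(
            // Same budget as the PR: 20 attempts, 500 ms apart.
            ConstantBuilder::default()
                .with_max_times(20)
                .with_delay(std::time::Duration::from_millis(500)),
        )
        .when(|e| {
            // Log and keep retrying on every error until the budget is exhausted.
            tracing::warn!("server_not_ready_yet: {:?}", e);
            true
        })
        .sleep(tokio::time::sleep)
        .await
}
```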
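
In `listener.rs`, `post_stop` now fires a oneshot that the rx task `select!`s on alongside the timeout-wrapped websocket stream, so the connection can be finalized before the loop exits instead of being aborted mid-stream. A stand-alone sketch of that loop, with stand-in stream and message types rather than the plugin's real ones:

```rust
// Sketch of the shutdown path added to the listener's rx task: a oneshot fired
// from post_stop interrupts the receive loop so a final flush can happen before
// the task exits. The response stream here is a simple mpsc stand-in.
use std::time::Duration;

use futures_util::StreamExt;
use tokio_stream::wrappers::ReceiverStream;

const LISTEN_STREAM_TIMEOUT: Duration = Duration::from_secs(60);

async fn rx_loop(
    mut shutdown_rx: tokio::sync::oneshot::Receiver<()>,
    responses: tokio::sync::mpsc::Receiver<String>,
) {
    let mut listen_stream = ReceiverStream::new(responses);

    loop {
        tokio::select! {
            // post_stop fired (or dropped) the sender: finalize and stop.
            _ = &mut shutdown_rx => {
                tracing::info!("finalize_requested");
                break;
            }
            result = tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()) => {
                match result {
                    Ok(Some(response)) => tracing::info!("transcript: {}", response),
                    Ok(None) => { tracing::info!("listen_stream_ended"); break; }
                    Err(_) => { tracing::info!("listen_stream_timeout"); break; }
                }
            }
        }
    }
}
```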