diff --git a/Cargo.lock b/Cargo.lock index d80d53f831..6cd4b8f17c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3924,6 +3924,7 @@ dependencies = [ name = "desktop" version = "0.0.0" dependencies = [ + "host", "ractor", "ractor-supervisor", "sentry", diff --git a/apps/desktop/src-tauri/src/lib.rs b/apps/desktop/src-tauri/src/lib.rs index 7c19fe368f..e562f27c23 100644 --- a/apps/desktop/src-tauri/src/lib.rs +++ b/apps/desktop/src-tauri/src/lib.rs @@ -80,11 +80,7 @@ pub async fn main() { .plugin(tauri_plugin_store::Builder::default().build()) .plugin(tauri_plugin_store2::init()) .plugin(tauri_plugin_windows::init()) - .plugin(tauri_plugin_listener::init( - tauri_plugin_listener::InitOptions { - parent_supervisor: root_supervisor.as_ref().map(|s| s.get_cell()), - }, - )) + .plugin(tauri_plugin_listener::init()) .plugin(tauri_plugin_listener2::init()) .plugin(tauri_plugin_local_stt::init( tauri_plugin_local_stt::InitOptions { diff --git a/apps/desktop/src/store/zustand/listener/general.ts b/apps/desktop/src/store/zustand/listener/general.ts index 1d8df835f5..f5494be9cd 100644 --- a/apps/desktop/src/store/zustand/listener/general.ts +++ b/apps/desktop/src/store/zustand/listener/general.ts @@ -6,10 +6,10 @@ import type { StoreApi } from "zustand"; import { commands as hooksCommands } from "@hypr/plugin-hooks"; import { - type ControllerParams, commands as listenerCommands, events as listenerEvents, type SessionEvent, + type SessionParams, type StreamResponse, } from "@hypr/plugin-listener"; import { @@ -48,7 +48,7 @@ export type GeneralState = { export type GeneralActions = { start: ( - params: ControllerParams, + params: SessionParams, options?: { handlePersist?: HandlePersistCallback }, ) => void; stop: () => void; @@ -80,7 +80,7 @@ const listenToSessionEvents = ( catch: (error) => error, }); -const startSessionEffect = (params: ControllerParams) => +const startSessionEffect = (params: SessionParams) => fromResult(listenerCommands.startSession(params)); const stopSessionEffect = () => fromResult(listenerCommands.stopSession()); @@ -95,7 +95,7 @@ export const createGeneralSlice = < get: StoreApi["getState"], ): GeneralState & GeneralActions => ({ ...initialState, - start: (params: ControllerParams, options) => { + start: (params: SessionParams, options) => { const targetSessionId = params.session_id; if (!targetSessionId) { diff --git a/plugins/listener/js/bindings.gen.ts b/plugins/listener/js/bindings.gen.ts index 377cd83fff..df83a80cce 100644 --- a/plugins/listener/js/bindings.gen.ts +++ b/plugins/listener/js/bindings.gen.ts @@ -47,7 +47,7 @@ async setMicMuted(muted: boolean) : Promise> { else return { status: "error", error: e as any }; } }, -async startSession(params: ControllerParams) : Promise> { +async startSession(params: SessionParams) : Promise> { try { return { status: "ok", data: await TAURI_INVOKE("plugin:listener|start_session", { params }) }; } catch (e) { @@ -88,8 +88,8 @@ sessionEvent: "plugin:listener:session-event" /** user-defined types **/ -export type ControllerParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[] } export type SessionEvent = { type: "inactive"; session_id: string } | { type: "running_active"; session_id: string } | { type: "finalizing"; session_id: string } | { type: "audioAmplitude"; session_id: string; mic: number; speaker: number } | { type: "micMuted"; session_id: string; value: boolean } | { type: "streamResponse"; session_id: string; response: StreamResponse } +export type SessionParams = { session_id: string; languages: string[]; onboarding: boolean; record_enabled: boolean; model: string; base_url: string; api_key: string; keywords: string[] } export type StreamAlternatives = { transcript: string; words: StreamWord[]; confidence: number; languages?: string[] } export type StreamChannel = { alternatives: StreamAlternatives[] } export type StreamExtra = { started_unix_millis: number } diff --git a/plugins/listener/src/actors/controller.rs b/plugins/listener/src/actors/controller.rs deleted file mode 100644 index 8fa59e95cd..0000000000 --- a/plugins/listener/src/actors/controller.rs +++ /dev/null @@ -1,362 +0,0 @@ -use std::time::{Instant, SystemTime}; - -use tauri_specta::Event; -use tokio_util::sync::CancellationToken; - -use ractor::{ - call_t, concurrency, registry, Actor, ActorCell, ActorName, ActorProcessingErr, ActorRef, - RpcReplyPort, SupervisionEvent, -}; - -use crate::{ - actors::{ - ListenerActor, ListenerArgs, ListenerMsg, RecArgs, RecMsg, RecorderActor, SourceActor, - SourceArgs, SourceMsg, - }, - SessionEvent, -}; - -#[derive(Debug)] -pub enum ControllerMsg { - SetMicMute(bool), - GetMicMute(RpcReplyPort), - GetMicDeviceName(RpcReplyPort>), - ChangeMicDevice(Option), - GetSessionId(RpcReplyPort), -} - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] -pub struct ControllerParams { - pub session_id: String, - pub languages: Vec, - pub onboarding: bool, - pub record_enabled: bool, - pub model: String, - pub base_url: String, - pub api_key: String, - pub keywords: Vec, -} - -pub struct ControllerArgs { - pub app: tauri::AppHandle, - pub params: ControllerParams, -} - -pub struct ControllerState { - app: tauri::AppHandle, - token: CancellationToken, - params: ControllerParams, - started_at_instant: Instant, - started_at_system: SystemTime, -} - -pub struct ControllerActor; - -impl ControllerActor { - pub fn name() -> ActorName { - "controller".into() - } -} - -#[ractor::async_trait] -impl Actor for ControllerActor { - type Msg = ControllerMsg; - type State = ControllerState; - type Arguments = ControllerArgs; - - async fn pre_start( - &self, - myself: ActorRef, - args: Self::Arguments, - ) -> Result { - let cancellation_token = CancellationToken::new(); - let started_at_instant = Instant::now(); - let started_at_system = SystemTime::now(); - - { - use tauri_plugin_tray::TrayPluginExt; - let _ = args.app.set_start_disabled(true); - } - - let state = ControllerState { - app: args.app, - token: cancellation_token, - params: args.params, - started_at_instant, - started_at_system, - }; - - { - let c = myself.get_cell(); - Self::start_all_actors(c, &state).await?; - } - - SessionEvent::RunningActive { - session_id: state.params.session_id.clone(), - } - .emit(&state.app) - .unwrap(); - Ok(state) - } - - async fn handle( - &self, - _myself: ActorRef, - message: Self::Msg, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match message { - ControllerMsg::SetMicMute(muted) => { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - actor.cast(SourceMsg::SetMicMute(muted))?; - } - SessionEvent::MicMuted { - session_id: state.params.session_id.clone(), - value: muted, - } - .emit(&state.app)?; - } - - ControllerMsg::GetMicDeviceName(reply) => { - if !reply.is_closed() { - let device_name = if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - call_t!(actor, SourceMsg::GetMicDevice, 100).unwrap_or(None) - } else { - None - }; - - let _ = reply.send(device_name); - } - } - - ControllerMsg::GetMicMute(reply) => { - let muted = if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - call_t!(actor, SourceMsg::GetMicMute, 100)? - } else { - false - }; - - if !reply.is_closed() { - let _ = reply.send(muted); - } - } - - ControllerMsg::ChangeMicDevice(device) => { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - actor.cast(SourceMsg::SetMicDevice(device))?; - } - } - - ControllerMsg::GetSessionId(reply) => { - if !reply.is_closed() { - let _ = reply.send(state.params.session_id.clone()); - } - } - } - - Ok(()) - } - - async fn handle_supervisor_evt( - &self, - myself: ActorRef, - event: SupervisionEvent, - _state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - match event { - SupervisionEvent::ActorStarted(actor) => { - tracing::info!("{:?}_actor_started", actor.get_name()); - } - SupervisionEvent::ActorTerminated(actor, _maybe_state, exit_reason) => { - let actor_name = actor - .get_name() - .map(|n| n.to_string()) - .unwrap_or_else(|| "unknown".to_string()); - - tracing::error!( - actor = %actor_name, - reason = ?exit_reason, - "child_actor_terminated_stopping_session" - ); - - myself.stop(None); - } - SupervisionEvent::ActorFailed(_, _) => {} - _ => {} - } - - Ok(()) - } - - async fn post_stop( - &self, - _myself: ActorRef, - state: &mut Self::State, - ) -> Result<(), ActorProcessingErr> { - state.token.cancel(); - - { - Self::stop_all_actors().await; - } - - { - use tauri_plugin_tray::TrayPluginExt; - let _ = state.app.set_start_disabled(false); - } - - SessionEvent::Inactive { - session_id: state.params.session_id.clone(), - } - .emit(&state.app)?; - - Ok(()) - } -} - -impl ControllerActor { - async fn start_all_actors( - supervisor: ActorCell, - state: &ControllerState, - ) -> Result<(), ActorProcessingErr> { - Self::start_source(supervisor.clone(), state).await?; - Self::start_listener(supervisor.clone(), state, None).await?; - - if state.params.record_enabled { - Self::start_recorder(supervisor, state).await?; - } - - Ok(()) - } - - async fn stop_all_actors() { - Self::stop_source().await; - Self::stop_listener().await; - Self::stop_recorder().await; - } - - async fn start_source( - supervisor: ActorCell, - state: &ControllerState, - ) -> Result, ActorProcessingErr> { - let (ar, _) = Actor::spawn_linked( - Some(SourceActor::name()), - SourceActor, - SourceArgs { - token: state.token.clone(), - mic_device: None, - onboarding: state.params.onboarding, - app: state.app.clone(), - session_id: state.params.session_id.clone(), - }, - supervisor, - ) - .await?; - Ok(ar) - } - - async fn stop_source() { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor - .stop_and_wait( - Some("restart".to_string()), - Some(concurrency::Duration::from_secs(3)), - ) - .await; - } - } - - async fn start_recorder( - supervisor: ActorCell, - state: &ControllerState, - ) -> Result, ActorProcessingErr> { - use tauri::{path::BaseDirectory, Manager}; - let app_dir = state - .app - .path() - .resolve("hyprnote/sessions", BaseDirectory::Data) - .map_err(|e| Box::new(e) as ActorProcessingErr)?; - let (rec_ref, _) = Actor::spawn_linked( - Some(RecorderActor::name()), - RecorderActor, - RecArgs { - app_dir, - session_id: state.params.session_id.clone(), - }, - supervisor, - ) - .await?; - Ok(rec_ref) - } - - async fn stop_recorder() { - if let Some(cell) = registry::where_is(RecorderActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor - .stop_and_wait( - Some("restart".to_string()), - Some(concurrency::Duration::from_secs(6)), - ) - .await; - } - } - - async fn start_listener( - supervisor: ActorCell, - session_state: &ControllerState, - listener_args: Option, - ) -> Result, ActorProcessingErr> { - use crate::actors::ChannelMode; - - let mode = if listener_args.is_none() { - if let Some(cell) = registry::where_is(SourceActor::name()) { - let actor: ActorRef = cell.into(); - match call_t!(actor, SourceMsg::GetMode, 500) { - Ok(m) => m, - Err(_) => ChannelMode::Dual, - } - } else { - ChannelMode::Dual - } - } else { - ChannelMode::Dual - }; - - let (listen_ref, _) = Actor::spawn_linked( - Some(ListenerActor::name()), - ListenerActor, - listener_args.unwrap_or(ListenerArgs { - app: session_state.app.clone(), - languages: session_state.params.languages.clone(), - onboarding: session_state.params.onboarding, - model: session_state.params.model.clone(), - base_url: session_state.params.base_url.clone(), - api_key: session_state.params.api_key.clone(), - keywords: session_state.params.keywords.clone(), - mode, - session_started_at: session_state.started_at_instant, - session_started_at_unix: session_state.started_at_system, - session_id: session_state.params.session_id.clone(), - }), - supervisor, - ) - .await?; - Ok(listen_ref) - } - - async fn stop_listener() { - if let Some(cell) = registry::where_is(ListenerActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor - .stop_and_wait( - Some("restart".to_string()), - Some(concurrency::Duration::from_secs(3)), - ) - .await; - } - } -} diff --git a/plugins/listener/src/actors/mod.rs b/plugins/listener/src/actors/mod.rs index f1794d007a..3547e3d4a3 100644 --- a/plugins/listener/src/actors/mod.rs +++ b/plugins/listener/src/actors/mod.rs @@ -1,9 +1,7 @@ -mod controller; mod listener; mod recorder; mod source; -pub use controller::*; pub use listener::*; pub use recorder::*; pub use source::*; @@ -21,5 +19,5 @@ pub enum ChannelMode { #[derive(Clone)] pub struct AudioChunk { - data: Vec, + pub data: Vec, } diff --git a/plugins/listener/src/actors/source.rs b/plugins/listener/src/actors/source.rs index 0682d6ddae..dbbdbb5255 100644 --- a/plugins/listener/src/actors/source.rs +++ b/plugins/listener/src/actors/source.rs @@ -24,28 +24,28 @@ use hypr_vad_ext::VadMaskExt; use tauri_specta::Event; const AUDIO_AMPLITUDE_THROTTLE: Duration = Duration::from_millis(100); +const MAX_BUFFER_CHUNKS: usize = 150; pub enum SourceMsg { SetMicMute(bool), GetMicMute(RpcReplyPort), SetMicDevice(Option), GetMicDevice(RpcReplyPort>), - GetMode(RpcReplyPort), + GetSessionId(RpcReplyPort), MicChunk(AudioChunk), SpeakerChunk(AudioChunk), } pub struct SourceArgs { pub mic_device: Option, - pub token: CancellationToken, pub onboarding: bool, pub app: tauri::AppHandle, pub session_id: String, } pub struct SourceState { + session_id: String, mic_device: Option, - token: CancellationToken, onboarding: bool, mic_muted: Arc, run_task: Option>, @@ -135,8 +135,8 @@ impl Actor for SourceActor { let pipeline = Pipeline::new(args.app.clone(), args.session_id.clone()); let mut st = SourceState { + session_id: args.session_id, mic_device, - token: args.token, onboarding: args.onboarding, mic_muted: Arc::new(AtomicBool::new(false)), run_task: None, @@ -171,6 +171,11 @@ impl Actor for SourceActor { let _ = reply.send(st.mic_device.clone()); } } + SourceMsg::GetSessionId(reply) => { + if !reply.is_closed() { + let _ = reply.send(st.session_id.clone()); + } + } SourceMsg::SetMicDevice(dev) => { st.mic_device = dev; st.pipeline.reset(); @@ -184,11 +189,6 @@ impl Actor for SourceActor { } start_source_loop(&myself, st).await?; } - SourceMsg::GetMode(reply) => { - if !reply.is_closed() { - let _ = reply.send(st.current_mode); - } - } SourceMsg::MicChunk(chunk) => { st.pipeline.ingest_mic(chunk); st.pipeline.flush(st.current_mode); @@ -266,7 +266,6 @@ async fn start_source_loop_single( #[cfg(any(target_os = "macos", target_os = "linux"))] { let myself2 = myself.clone(); - let token = st.token.clone(); let mic_muted = st.mic_muted.clone(); let mic_device = st.mic_device.clone(); @@ -287,11 +286,6 @@ async fn start_source_loop_single( loop { tokio::select! { - _ = token.cancelled() => { - drop(mic_stream); - myself2.stop(None); - return; - } _ = stream_cancel_token.cancelled() => { drop(mic_stream); return; @@ -326,7 +320,6 @@ async fn start_source_loop_dual( st: &mut SourceState, ) -> Result<(), ActorProcessingErr> { let myself2 = myself.clone(); - let token = st.token.clone(); let mic_muted = st.mic_muted.clone(); let mic_device = st.mic_device.clone(); @@ -358,12 +351,6 @@ async fn start_source_loop_dual( loop { tokio::select! { - _ = token.cancelled() => { - drop(mic_stream); - drop(spk_stream); - myself2.stop(None); - return; - } _ = stream_cancel_token.cancelled() => { drop(mic_stream); drop(spk_stream); @@ -410,6 +397,7 @@ struct Pipeline { aec: Option, joiner: Joiner, amplitude: AmplitudeEmitter, + audio_buffer: AudioBuffer, } impl Pipeline { @@ -420,6 +408,7 @@ impl Pipeline { aec: None, joiner: Joiner::new(), amplitude: AmplitudeEmitter::new(app, session_id), + audio_buffer: AudioBuffer::new(MAX_BUFFER_CHUNKS), } } @@ -431,6 +420,7 @@ impl Pipeline { aec.reset(); } self.amplitude.reset(); + self.audio_buffer.clear(); } fn ingest_mic(&mut self, chunk: AudioChunk) { @@ -485,27 +475,85 @@ impl Pipeline { } let Some(cell) = registry::where_is(ListenerActor::name()) else { - tracing::debug!(actor = ListenerActor::name(), "unavailable"); + self.audio_buffer.push(processed_mic, processed_spk, mode); + tracing::debug!( + actor = ListenerActor::name(), + buffered = self.audio_buffer.len(), + "listener_unavailable_buffering" + ); return; }; let actor: ActorRef = cell.into(); + self.flush_buffer_to_listener(&actor, mode); + + self.send_to_listener(&actor, &processed_mic, &processed_spk, mode); + + self.amplitude.observe(processed_mic, processed_spk); + } + + fn flush_buffer_to_listener(&mut self, actor: &ActorRef, mode: ChannelMode) { + while let Some((mic, spk, buffered_mode)) = self.audio_buffer.pop() { + if buffered_mode == mode { + self.send_to_listener(actor, &mic, &spk, mode); + } + } + } + + fn send_to_listener( + &self, + actor: &ActorRef, + mic: &Arc<[f32]>, + spk: &Arc<[f32]>, + mode: ChannelMode, + ) { let result = if mode == ChannelMode::Single { - let audio_bytes = f32_to_i16_bytes(processed_mic.to_vec().iter().copied()); + let audio_bytes = f32_to_i16_bytes(mic.to_vec().iter().copied()); actor.cast(ListenerMsg::AudioSingle(audio_bytes)) } else { - let mic_bytes = f32_to_i16_bytes(processed_mic.iter().copied()); - let spk_bytes = f32_to_i16_bytes(processed_spk.iter().copied()); + let mic_bytes = f32_to_i16_bytes(mic.iter().copied()); + let spk_bytes = f32_to_i16_bytes(spk.iter().copied()); actor.cast(ListenerMsg::AudioDual(mic_bytes, spk_bytes)) }; if result.is_err() { tracing::warn!(actor = ListenerActor::name(), "cast_failed"); - return; } + } +} - self.amplitude.observe(processed_mic, processed_spk); +struct AudioBuffer { + buffer: VecDeque<(Arc<[f32]>, Arc<[f32]>, ChannelMode)>, + max_size: usize, +} + +impl AudioBuffer { + fn new(max_size: usize) -> Self { + Self { + buffer: VecDeque::new(), + max_size, + } + } + + fn push(&mut self, mic: Arc<[f32]>, spk: Arc<[f32]>, mode: ChannelMode) { + if self.buffer.len() >= self.max_size { + self.buffer.pop_front(); + tracing::warn!("audio_buffer_overflow"); + } + self.buffer.push_back((mic, spk, mode)); + } + + fn pop(&mut self) -> Option<(Arc<[f32]>, Arc<[f32]>, ChannelMode)> { + self.buffer.pop_front() + } + + fn len(&self) -> usize { + self.buffer.len() + } + + fn clear(&mut self) { + self.buffer.clear(); } } diff --git a/plugins/listener/src/commands.rs b/plugins/listener/src/commands.rs index 7840e4ada1..9d2cbc04fc 100644 --- a/plugins/listener/src/commands.rs +++ b/plugins/listener/src/commands.rs @@ -1,4 +1,4 @@ -use crate::{actors::ControllerParams, ListenerPluginExt}; +use crate::{supervisor::SessionParams, ListenerPluginExt}; #[tauri::command] #[specta::specta] @@ -51,7 +51,7 @@ pub async fn set_mic_muted( #[specta::specta] pub async fn start_session( app: tauri::AppHandle, - params: ControllerParams, + params: SessionParams, ) -> Result<(), String> { app.start_session(params).await; Ok(()) diff --git a/plugins/listener/src/ext.rs b/plugins/listener/src/ext.rs index f3576da86d..3a9a5e524f 100644 --- a/plugins/listener/src/ext.rs +++ b/plugins/listener/src/ext.rs @@ -1,10 +1,13 @@ use std::future::Future; +use std::time::{Instant, SystemTime}; -use ractor::{call_t, concurrency, registry, Actor, ActorRef}; +use ractor::{call_t, registry, ActorRef}; +use tauri::{path::BaseDirectory, Manager}; use tauri_specta::Event; use crate::{ - actors::{ControllerActor, ControllerArgs, ControllerMsg, ControllerParams}, + actors::{SourceActor, SourceMsg}, + supervisor::{spawn_session_supervisor, SessionContext, SessionParams}, SessionEvent, }; @@ -23,7 +26,7 @@ pub trait ListenerPluginExt { fn get_state(&self) -> impl Future; fn stop_session(&self) -> impl Future; - fn start_session(&self, params: ControllerParams) -> impl Future; + fn start_session(&self, params: SessionParams) -> impl Future; } impl> ListenerPluginExt for T { @@ -34,15 +37,14 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn get_current_microphone_device(&self) -> Result, crate::Error> { - if let Some(cell) = registry::where_is(ControllerActor::name()) { - let actor: ActorRef = cell.into(); - - match call_t!(actor, ControllerMsg::GetMicDeviceName, 500) { + if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + match call_t!(actor, SourceMsg::GetMicDevice, 500) { Ok(device_name) => Ok(device_name), Err(_) => Ok(None), } } else { - Err(crate::Error::ActorNotFound(ControllerActor::name())) + Err(crate::Error::ActorNotFound(SourceActor::name())) } } @@ -51,29 +53,25 @@ impl> ListenerPluginExt for T { &self, device_name: impl Into, ) -> Result<(), crate::Error> { - if let Some(cell) = registry::where_is(ControllerActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor.cast(ControllerMsg::ChangeMicDevice(Some(device_name.into()))); + if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + let _ = actor.cast(SourceMsg::SetMicDevice(Some(device_name.into()))); } - Ok(()) } #[tracing::instrument(skip_all)] async fn get_state(&self) -> crate::fsm::State { - if registry::where_is(ControllerActor::name()).is_some() { - crate::fsm::State::RunningActive - } else { - crate::fsm::State::Inactive - } + let state = self.state::(); + let guard = state.lock().await; + guard.get_state() } #[tracing::instrument(skip_all)] async fn get_mic_muted(&self) -> bool { - if let Some(cell) = registry::where_is(ControllerActor::name()) { - let actor: ActorRef = cell.into(); - - match call_t!(actor, ControllerMsg::GetMicMute, 100) { + if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + match call_t!(actor, SourceMsg::GetMicMute, 100) { Ok(muted) => muted, Err(_) => false, } @@ -84,48 +82,106 @@ impl> ListenerPluginExt for T { #[tracing::instrument(skip_all)] async fn set_mic_muted(&self, muted: bool) { - if let Some(cell) = registry::where_is(ControllerActor::name()) { - let actor: ActorRef = cell.into(); - let _ = actor.cast(ControllerMsg::SetMicMute(muted)); + if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + let _ = actor.cast(SourceMsg::SetMicMute(muted)); } } #[tracing::instrument(skip_all)] - async fn start_session(&self, params: ControllerParams) { + async fn start_session(&self, params: SessionParams) { let state = self.state::(); - let guard = state.lock().await; + let mut guard = state.lock().await; + + if guard.session_supervisor.is_some() { + tracing::warn!("session_already_running"); + return; + } + + let app_dir = match guard + .app + .path() + .resolve("hyprnote/sessions", BaseDirectory::Data) + { + Ok(dir) => dir, + Err(e) => { + tracing::error!(error = ?e, "failed_to_resolve_app_dir"); + return; + } + }; + + { + use tauri_plugin_tray::TrayPluginExt; + let _ = guard.app.set_start_disabled(true); + } + + let ctx = SessionContext { + app: guard.app.clone(), + params: params.clone(), + app_dir, + started_at_instant: Instant::now(), + started_at_system: SystemTime::now(), + }; + + match spawn_session_supervisor(ctx).await { + Ok((supervisor_ref, handle)) => { + guard.session_supervisor = Some(supervisor_ref); + guard.supervisor_handle = Some(handle); + + SessionEvent::RunningActive { + session_id: params.session_id, + } + .emit(&guard.app) + .unwrap(); + + tracing::info!("session_started"); + } + Err(e) => { + tracing::error!(error = ?e, "failed_to_start_session"); - let _ = Actor::spawn( - Some(ControllerActor::name()), - ControllerActor, - ControllerArgs { - app: guard.app.clone(), - params, - }, - ) - .await; + use tauri_plugin_tray::TrayPluginExt; + let _ = guard.app.set_start_disabled(false); + } + } } #[tracing::instrument(skip_all)] async fn stop_session(&self) { - if let Some(cell) = registry::where_is(ControllerActor::name()) { - let actor: ActorRef = cell.into(); - - let session_id = call_t!(actor, ControllerMsg::GetSessionId, 100).ok(); - - { - let state = self.state::(); - let guard = state.lock().await; - if let Some(session_id) = session_id.clone() { - SessionEvent::Finalizing { session_id } - .emit(&guard.app) - .unwrap(); - } - } + let state = self.state::(); + let mut guard = state.lock().await; + + let session_id = if let Some(cell) = registry::where_is(SourceActor::name()) { + let actor: ActorRef = cell.into(); + call_t!(actor, SourceMsg::GetSessionId, 100).ok() + } else { + None + }; - let _ = actor - .stop_and_wait(None, Some(concurrency::Duration::from_secs(10))) - .await; + if let Some(session_id) = session_id.clone() { + SessionEvent::Finalizing { session_id } + .emit(&guard.app) + .unwrap(); } + + if let Some(supervisor_cell) = guard.session_supervisor.take() { + supervisor_cell.stop(None); + } + + if let Some(handle) = guard.supervisor_handle.take() { + let _ = handle.await; + } + + { + use tauri_plugin_tray::TrayPluginExt; + let _ = guard.app.set_start_disabled(false); + } + + if let Some(session_id) = session_id { + SessionEvent::Inactive { session_id } + .emit(&guard.app) + .unwrap(); + } + + tracing::info!("session_stopped"); } } diff --git a/plugins/listener/src/lib.rs b/plugins/listener/src/lib.rs index cef67f9652..0089e3bf2b 100644 --- a/plugins/listener/src/lib.rs +++ b/plugins/listener/src/lib.rs @@ -1,5 +1,3 @@ -use ractor::ActorCell; -use ractor_supervisor::dynamic::DynamicSupervisorMsg; use tauri::Manager; use tokio::sync::Mutex; @@ -14,7 +12,7 @@ mod supervisor; pub use error::*; pub use events::*; pub use ext::*; -pub use supervisor::{SupervisorHandle, SupervisorRef, SUPERVISOR_NAME}; +pub use supervisor::{session_supervisor_name, SessionContext, SessionParams}; const PLUGIN_NAME: &str = "listener"; @@ -22,18 +20,13 @@ pub type SharedState = std::sync::Arc>; pub struct State { pub app: tauri::AppHandle, - pub listener_supervisor: Option>, - pub supervisor_handle: Option, -} - -#[derive(Default)] -pub struct InitOptions { - pub parent_supervisor: Option, + pub session_supervisor: Option, + pub supervisor_handle: Option>, } impl State { - pub async fn get_state(&self) -> fsm::State { - if ractor::registry::where_is(actors::ControllerActor::name()).is_some() { + pub fn get_state(&self) -> fsm::State { + if self.session_supervisor.is_some() { crate::fsm::State::RunningActive } else { crate::fsm::State::Inactive @@ -58,7 +51,7 @@ fn make_specta_builder() -> tauri_specta::Builder { .error_handling(tauri_specta::ErrorHandlingMode::Result) } -pub fn init(options: InitOptions) -> tauri::plugin::TauriPlugin { +pub fn init() -> tauri::plugin::TauriPlugin { let specta_builder = make_specta_builder(); tauri::plugin::Builder::new(PLUGIN_NAME) @@ -70,26 +63,11 @@ pub fn init(options: InitOptions) -> tauri::plugin::TauriPlugin { let state: SharedState = std::sync::Arc::new(Mutex::new(State { app: app_handle, - listener_supervisor: None, + session_supervisor: None, supervisor_handle: None, })); - app.manage(state.clone()); - - let parent = options.parent_supervisor.clone(); - tauri::async_runtime::spawn(async move { - match supervisor::spawn_listener_supervisor(parent).await { - Ok((supervisor, handle)) => { - let mut guard = state.lock().await; - guard.listener_supervisor = Some(supervisor); - guard.supervisor_handle = Some(handle); - tracing::info!("listener_supervisor_spawned"); - } - Err(e) => { - tracing::error!("failed_to_spawn_listener_supervisor: {:?}", e); - } - } - }); + app.manage(state); Ok(()) }) diff --git a/plugins/listener/src/supervisor.rs b/plugins/listener/src/supervisor.rs index 5b2ef7ee90..0164fb50d4 100644 --- a/plugins/listener/src/supervisor.rs +++ b/plugins/listener/src/supervisor.rs @@ -1,33 +1,162 @@ -use ractor::{ActorCell, ActorProcessingErr, ActorRef}; -use ractor_supervisor::dynamic::{ - DynamicSupervisor, DynamicSupervisorMsg, DynamicSupervisorOptions, +use std::path::PathBuf; +use std::time::{Instant, SystemTime}; + +use ractor::concurrency::Duration; +use ractor::{Actor, ActorCell, ActorProcessingErr}; +use ractor_supervisor::core::{ChildBackoffFn, ChildSpec, Restart, SpawnFn}; +use ractor_supervisor::supervisor::{Supervisor, SupervisorArguments, SupervisorOptions}; +use ractor_supervisor::SupervisorStrategy; + +use crate::actors::{ + ChannelMode, ListenerActor, ListenerArgs, RecArgs, RecorderActor, SourceActor, SourceArgs, }; -pub type SupervisorRef = ActorRef; -pub type SupervisorHandle = tokio::task::JoinHandle<()>; +pub const SESSION_SUPERVISOR_PREFIX: &str = "session_supervisor_"; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)] +pub struct SessionParams { + pub session_id: String, + pub languages: Vec, + pub onboarding: bool, + pub record_enabled: bool, + pub model: String, + pub base_url: String, + pub api_key: String, + pub keywords: Vec, +} -pub const SUPERVISOR_NAME: &str = "listener_supervisor"; +#[derive(Clone)] +pub struct SessionContext { + pub app: tauri::AppHandle, + pub params: SessionParams, + pub app_dir: PathBuf, + pub started_at_instant: Instant, + pub started_at_system: SystemTime, +} + +pub fn session_supervisor_name(session_id: &str) -> String { + format!("{}{}", SESSION_SUPERVISOR_PREFIX, session_id) +} -fn make_supervisor_options() -> DynamicSupervisorOptions { - DynamicSupervisorOptions { - max_children: Some(10), - max_restarts: 50, - max_window: ractor::concurrency::Duration::from_secs(60), - reset_after: Some(ractor::concurrency::Duration::from_secs(30)), +fn make_supervisor_options() -> SupervisorOptions { + SupervisorOptions { + strategy: SupervisorStrategy::RestForOne, + max_restarts: 10, + max_window: Duration::from_secs(30), + reset_after: Some(Duration::from_secs(60)), } } -pub async fn spawn_listener_supervisor( - parent: Option, -) -> Result<(SupervisorRef, SupervisorHandle), ActorProcessingErr> { - let options = make_supervisor_options(); +fn make_listener_backoff() -> ChildBackoffFn { + ChildBackoffFn::new(|_id, count, _, _| { + if count <= 1 { + None + } else { + Some(Duration::from_millis(500 * (1 << count.min(5)))) + } + }) +} + +pub async fn spawn_session_supervisor( + ctx: SessionContext, +) -> Result<(ActorCell, tokio::task::JoinHandle<()>), ActorProcessingErr> { + let supervisor_name = session_supervisor_name(&ctx.params.session_id); + + let mut child_specs = Vec::new(); + + let ctx_source = ctx.clone(); + child_specs.push(ChildSpec { + id: SourceActor::name().to_string(), + restart: Restart::Permanent, + spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { + let ctx = ctx_source.clone(); + async move { + let (actor_ref, _) = Actor::spawn_linked( + Some(SourceActor::name()), + SourceActor, + SourceArgs { + mic_device: None, + onboarding: ctx.params.onboarding, + app: ctx.app.clone(), + session_id: ctx.params.session_id.clone(), + }, + supervisor_cell, + ) + .await?; + Ok(actor_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: Some(Duration::from_secs(30)), + }); - let (supervisor_ref, handle) = - DynamicSupervisor::spawn(SUPERVISOR_NAME.to_string(), options).await?; + let ctx_listener = ctx.clone(); + child_specs.push(ChildSpec { + id: ListenerActor::name().to_string(), + restart: Restart::Permanent, + spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { + let ctx = ctx_listener.clone(); + async move { + let mode = ChannelMode::Dual; - if let Some(parent_cell) = parent { - supervisor_ref.get_cell().link(parent_cell); + let (actor_ref, _) = Actor::spawn_linked( + Some(ListenerActor::name()), + ListenerActor, + ListenerArgs { + app: ctx.app.clone(), + languages: ctx.params.languages.clone(), + onboarding: ctx.params.onboarding, + model: ctx.params.model.clone(), + base_url: ctx.params.base_url.clone(), + api_key: ctx.params.api_key.clone(), + keywords: ctx.params.keywords.clone(), + mode, + session_started_at: ctx.started_at_instant, + session_started_at_unix: ctx.started_at_system, + session_id: ctx.params.session_id.clone(), + }, + supervisor_cell, + ) + .await?; + Ok(actor_ref.get_cell()) + } + }), + backoff_fn: Some(make_listener_backoff()), + reset_after: Some(Duration::from_secs(30)), + }); + + if ctx.params.record_enabled { + let ctx_recorder = ctx.clone(); + child_specs.push(ChildSpec { + id: RecorderActor::name().to_string(), + restart: Restart::Transient, + spawn_fn: SpawnFn::new(move |supervisor_cell, _id| { + let ctx = ctx_recorder.clone(); + async move { + let (actor_ref, _) = Actor::spawn_linked( + Some(RecorderActor::name()), + RecorderActor, + RecArgs { + app_dir: ctx.app_dir.clone(), + session_id: ctx.params.session_id.clone(), + }, + supervisor_cell, + ) + .await?; + Ok(actor_ref.get_cell()) + } + }), + backoff_fn: None, + reset_after: None, + }); } - Ok((supervisor_ref, handle)) + let args = SupervisorArguments { + child_specs, + options: make_supervisor_options(), + }; + + let (supervisor_ref, handle) = Supervisor::spawn(supervisor_name, args).await?; + + Ok((supervisor_ref.get_cell(), handle)) }