Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/listener-core/src/actors/recorder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub enum RecMsg {
/// Start-up arguments handed to the recorder actor when it is spawned.
pub struct RecArgs {
/// Application directory supplied by the spawner (the session supervisor
/// passes its own `ctx.app_dir`). NOTE(review): presumably the root under
/// which the recording is written; confirm against the actor's start-up code.
pub app_dir: PathBuf,
/// Identifier of the session being recorded (the session supervisor passes
/// `ctx.params.session_id`).
pub session_id: String,
/// Optional one-shot completion signal. The value is moved into
/// `RecState::done_tx`; the recorder later `take()`s it and sends `()` once.
/// `None` means nobody is waiting on the signal.
pub done_tx: Option<tokio::sync::oneshot::Sender<()>>,
}

pub struct RecState {
Expand All @@ -34,6 +35,7 @@ pub struct RecState {
wav_path: PathBuf,
last_flush: Instant,
is_stereo: bool,
done_tx: Option<tokio::sync::oneshot::Sender<()>>,
}

pub struct RecorderActor<E: AudioCodec = Mp3Codec> {
Expand Down Expand Up @@ -134,6 +136,7 @@ impl<E: AudioCodec> Actor for RecorderActor<E> {
wav_path,
last_flush: Instant::now(),
is_stereo,
done_tx: args.done_tx,
})
}

Expand Down Expand Up @@ -207,6 +210,9 @@ impl<E: AudioCodec> Actor for RecorderActor<E> {
}
}

if let Some(tx) = st.done_tx.take() {
let _ = tx.send(());
}
Ok(())
}
}
Expand Down
34 changes: 29 additions & 5 deletions crates/listener-core/src/actors/session/supervisor.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::Arc;

use hypr_supervisor::{RestartBudget, RestartTracker, RetryStrategy, spawn_with_retry};
use ractor::concurrency::Duration;
use ractor::{Actor, ActorCell, ActorProcessingErr, ActorRef, SupervisionEvent};
Expand Down Expand Up @@ -34,6 +36,7 @@ pub struct SessionState {
source_cell: Option<ActorCell>,
listener_cell: Option<ActorCell>,
recorder_cell: Option<ActorCell>,
recorder_done: Option<tokio::sync::oneshot::Receiver<()>>,
source_restarts: RestartTracker,
recorder_restarts: RestartTracker,
shutting_down: bool,
Expand Down Expand Up @@ -74,27 +77,30 @@ impl Actor for SessionActor {
)
.await?;

let recorder_cell = if ctx.params.record_enabled {
let (recorder_cell, recorder_done) = if ctx.params.record_enabled {
let (done_tx, done_rx) = tokio::sync::oneshot::channel();
let (recorder_ref, _): (ActorRef<RecMsg>, _) = Actor::spawn_linked(
Some(RecorderActor::name()),
RecorderActor::new(),
RecArgs {
app_dir: ctx.app_dir.clone(),
session_id: ctx.params.session_id.clone(),
done_tx: Some(done_tx),
},
myself.get_cell(),
)
.await?;
Some(recorder_ref.get_cell())
(Some(recorder_ref.get_cell()), Some(done_rx))
} else {
None
(None, None)
};

Ok(SessionState {
ctx,
source_cell: Some(source_ref.get_cell()),
listener_cell: None,
recorder_cell,
recorder_done,
source_restarts: RestartTracker::new(),
recorder_restarts: RestartTracker::new(),
shutting_down: false,
Expand Down Expand Up @@ -170,8 +176,9 @@ impl Actor for SessionActor {
state.shutting_down = true;

if let Some(cell) = state.recorder_cell.take() {
let done = state.recorder_done.take();
cell.stop(Some("session_stop".to_string()));
lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await;
wait_for_recorder_done(done).await;
}

if let Some(cell) = state.source_cell.take() {
Expand Down Expand Up @@ -367,18 +374,22 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta
let sup = supervisor_cell;
let app_dir = state.ctx.app_dir.clone();
let session_id = state.ctx.params.session_id.clone();
let (done_tx, done_rx) = tokio::sync::oneshot::channel();
let done_tx = Arc::new(std::sync::Mutex::new(Some(done_tx)));

let cell = spawn_with_retry(&RETRY_STRATEGY, || {
let sup = sup.clone();
let app_dir = app_dir.clone();
let session_id = session_id.clone();
let done_tx = done_tx.lock().unwrap().take();
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recorder retry loses done signal, skipping shutdown wait

Medium Severity

In try_restart_recorder, done_tx is taken from the Arc<Mutex<Option<...>>> on the first retry call via .take(). If that attempt fails, the sender is dropped (closing the channel), and subsequent retries get None. If a later retry succeeds, the recorder actor has done_tx: None and can never signal completion. Meanwhile state.recorder_done holds a done_rx whose sender was already dropped, so wait_for_recorder_done resolves immediately instead of waiting for the recorder to finish encoding — risking data loss of the WAV-to-MP3 conversion on shutdown.

Additional Locations (1)

Fix in Cursor Fix in Web

async move {
let (r, _): (ActorRef<RecMsg>, _) = Actor::spawn_linked(
Some(RecorderActor::name()),
RecorderActor::new(),
RecArgs {
app_dir,
session_id,
done_tx,
},
sup,
)
Expand All @@ -391,6 +402,7 @@ async fn try_restart_recorder(supervisor_cell: ActorCell, state: &mut SessionSta
match cell {
Some(c) => {
state.recorder_cell = Some(c);
state.recorder_done = Some(done_rx);
true
}
None => false,
Expand All @@ -407,12 +419,24 @@ async fn meltdown(myself: ActorRef<SessionMsg>, state: &mut SessionState) {
cell.stop(Some("meltdown".to_string()));
}
if let Some(cell) = state.recorder_cell.take() {
let done = state.recorder_done.take();
cell.stop(Some("meltdown".to_string()));
lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await;
wait_for_recorder_done(done).await;
}
myself.stop(Some("restart_limit_exceeded".to_string()));
}

/// Waits for the recorder actor to finish its shutdown work.
///
/// * `Some(rx)`: waits up to 30 seconds for the recorder's completion signal.
///   - Signal received: returns immediately.
///   - Timeout elapsed: gives up and returns (best effort, same as before).
///   - Channel closed without a signal: the sender was dropped before the
///     recorder could use it (for example, a spawn attempt in
///     `try_restart_recorder` consumed the sender and then failed). A closed
///     channel says nothing about whether the recorder has finished, so this
///     falls back to waiting for the actor itself to shut down instead of
///     returning right away.
/// * `None`: no completion channel exists, so waits for the actor to shut down.
async fn wait_for_recorder_done(done: Option<tokio::sync::oneshot::Receiver<()>>) {
    let signalled = match done {
        Some(rx) => match tokio::time::timeout(Duration::from_secs(30), rx).await {
            // Explicit completion signal from the recorder.
            Ok(Ok(())) => true,
            // Timed out: keep the original best-effort behaviour and move on.
            Err(_) => true,
            // Sender dropped without sending: not a completion signal.
            Ok(Err(_)) => false,
        },
        None => false,
    };

    if !signalled {
        lifecycle::wait_for_actor_shutdown(RecorderActor::name()).await;
    }
}

fn classify_connection_failure(base_url: &str) -> String {
if base_url.contains("localhost") || base_url.contains("127.0.0.1") {
"Local transcription server is not running".to_string()
Expand Down
36 changes: 36 additions & 0 deletions crates/owhisper-client/src/providers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,42 @@ impl Provider {
}
}

/// Renders a generic control message as the text frame this provider expects.
///
/// Returns `None` when the provider's adapter has no keep-alive message, when
/// the adapter's message is not a text frame, or when `msg` is `CloseStream`.
pub fn translate_control_message(
    &self,
    msg: &owhisper_interface::ControlMessage,
) -> Option<String> {
    use crate::adapter::RealtimeSttAdapter;
    use hypr_ws_client::client::Message;
    use owhisper_interface::ControlMessage;

    /// Keeps only text frames, converting their payload to an owned `String`.
    fn text_of(frame: Message) -> Option<String> {
        if let Message::Text(payload) = frame {
            Some(payload.to_string())
        } else {
            None
        }
    }

    /// Asks the adapter for the frame matching `control` and extracts its text.
    fn render<A: RealtimeSttAdapter>(stt: &A, control: &ControlMessage) -> Option<String> {
        let frame = match control {
            ControlMessage::KeepAlive => stt.keep_alive_message()?,
            ControlMessage::Finalize => stt.finalize_message(),
            // Closing the stream is never translated into a text frame.
            ControlMessage::CloseStream => return None,
        };
        text_of(frame)
    }

    match self {
        Self::Deepgram => render(&crate::adapter::DeepgramAdapter, msg),
        Self::AssemblyAI => render(&crate::adapter::AssemblyAIAdapter, msg),
        Self::Soniox => render(&crate::adapter::SonioxAdapter, msg),
        Self::Fireworks => render(&crate::adapter::FireworksAdapter, msg),
        Self::OpenAI => render(&crate::adapter::OpenAIAdapter, msg),
        Self::Gladia => render(&crate::adapter::GladiaAdapter, msg),
        Self::ElevenLabs => render(&crate::adapter::ElevenLabsAdapter, msg),
        Self::DashScope => render(&crate::adapter::DashScopeAdapter, msg),
        Self::Mistral => render(&crate::adapter::MistralAdapter::default(), msg),
    }
}

pub fn detect_error(&self, data: &[u8]) -> Option<ProviderError> {
match self {
Self::Deepgram => deepgram::error::detect_error(data),
Expand Down
16 changes: 15 additions & 1 deletion crates/transcribe-proxy/src/relay/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use owhisper_client::Auth;
pub use tokio_tungstenite::tungstenite::ClientRequestBuilder;

use super::handler::WebSocketProxy;
use super::types::{FirstMessageTransformer, InitialMessage, OnCloseCallback, ResponseTransformer};
use super::types::{
ClientMessageFilter, FirstMessageTransformer, InitialMessage, OnCloseCallback,
ResponseTransformer,
};
use crate::config::DEFAULT_CONNECT_TIMEOUT_MS;
use crate::provider_selector::SelectedProvider;

Expand Down Expand Up @@ -34,6 +37,7 @@ pub struct WebSocketProxyBuilder<S = NoUpstream> {
response_transformer: Option<ResponseTransformer>,
connect_timeout: Duration,
on_close: Option<OnCloseCallback>,
client_message_filter: Option<ClientMessageFilter>,
}

impl Default for WebSocketProxyBuilder<NoUpstream> {
Expand All @@ -46,6 +50,7 @@ impl Default for WebSocketProxyBuilder<NoUpstream> {
response_transformer: None,
connect_timeout: Duration::from_millis(DEFAULT_CONNECT_TIMEOUT_MS),
on_close: None,
client_message_filter: None,
}
}
}
Expand All @@ -60,6 +65,7 @@ impl<S> WebSocketProxyBuilder<S> {
response_transformer: self.response_transformer,
connect_timeout: self.connect_timeout,
on_close: self.on_close,
client_message_filter: self.client_message_filter,
}
}

Expand All @@ -71,6 +77,7 @@ impl<S> WebSocketProxyBuilder<S> {
response_transformer: Option<ResponseTransformer>,
connect_timeout: Duration,
on_close: Option<OnCloseCallback>,
client_message_filter: Option<ClientMessageFilter>,
) -> WebSocketProxy {
let control_message_types = if control_message_types.is_empty() {
None
Expand All @@ -86,6 +93,7 @@ impl<S> WebSocketProxyBuilder<S> {
response_transformer,
connect_timeout,
on_close,
client_message_filter,
)
}

Expand Down Expand Up @@ -131,6 +139,11 @@ impl<S> WebSocketProxyBuilder<S> {
}));
self
}

/// Sets the filter stored in `client_message_filter`, replacing any earlier
/// one, and returns the builder for chaining.
pub fn client_message_filter(self, filter: ClientMessageFilter) -> Self {
    let mut builder = self;
    builder.client_message_filter = Some(filter);
    builder
}
}

impl WebSocketProxyBuilder<NoUpstream> {
Expand Down Expand Up @@ -194,6 +207,7 @@ impl WebSocketProxyBuilder<WithUrl> {
self.response_transformer,
self.connect_timeout,
self.on_close,
self.client_message_filter,
))
}
}
Expand Down
23 changes: 21 additions & 2 deletions crates/transcribe-proxy/src/relay/channel_split.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use tokio_tungstenite::{

use owhisper_client::Provider;

use super::types::{InitialMessage, OnCloseCallback, ResponseTransformer, convert};
use super::types::{
ClientMessageFilter, InitialMessage, OnCloseCallback, ResponseTransformer, convert,
};

const SAMPLE_BYTES: usize = 2;
const FRAME_BYTES: usize = SAMPLE_BYTES * 2;
Expand Down Expand Up @@ -71,6 +73,7 @@ pub struct ChannelSplitProxy {
response_transformer: Option<ResponseTransformer>,
connect_timeout: Duration,
on_close: Option<OnCloseCallback>,
client_message_filter: Option<ClientMessageFilter>,
}

impl ChannelSplitProxy {
Expand Down Expand Up @@ -106,9 +109,15 @@ impl ChannelSplitProxy {
response_transformer,
connect_timeout,
on_close,
client_message_filter: None,
}
}

/// Attaches a client-message filter to this proxy, replacing any earlier one,
/// and returns the proxy for chaining.
pub fn with_client_message_filter(self, filter: ClientMessageFilter) -> Self {
    let mut proxy = self;
    proxy.client_message_filter = Some(filter);
    proxy
}

async fn connect_upstream(
request: &ClientRequestBuilder,
timeout: Duration,
Expand Down Expand Up @@ -155,6 +164,7 @@ impl ChannelSplitProxy {
spk_upstream,
self.initial_message.clone(),
self.response_transformer.clone(),
self.client_message_filter.clone(),
)
.await;

Expand All @@ -177,6 +187,7 @@ impl ChannelSplitProxy {
spk_upstream: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
initial_message: Option<InitialMessage>,
response_transformer: Option<ResponseTransformer>,
client_message_filter: Option<ClientMessageFilter>,
) {
let (mut mic_tx, mut mic_rx) = mic_upstream.split();
let (mut spk_tx, mut spk_rx) = spk_upstream.split();
Expand Down Expand Up @@ -218,7 +229,15 @@ impl ChannelSplitProxy {
}
}
Message::Text(text) => {
let tung = TungsteniteMessage::Text(text.to_string().into());
let text_str = text.to_string();
let forwarded = match client_message_filter.as_ref() {
Some(filter) => match filter(text_str) {
Some(s) => s,
None => continue,
},
None => text_str,
};
let tung = TungsteniteMessage::Text(forwarded.into());
if mic_tx.send(tung.clone()).await.is_err()
|| spk_tx.send(tung).await.is_err()
{
Expand Down
Loading
Loading