diff --git a/sdk/rust/Cargo.toml b/sdk/rust/Cargo.toml index 7ec7823a..94794697 100644 --- a/sdk/rust/Cargo.toml +++ b/sdk/rust/Cargo.toml @@ -22,6 +22,7 @@ serde_json = "1" thiserror = "2" tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] } tokio-stream = "0.1" +tokio-util = "0.7" futures-core = "0.3" reqwest = { version = "0.12", features = ["json"] } urlencoding = "2" diff --git a/sdk/rust/src/detail/core_interop.rs b/sdk/rust/src/detail/core_interop.rs index 43884d7f..0d17fe62 100644 --- a/sdk/rust/src/detail/core_interop.rs +++ b/sdk/rust/src/detail/core_interop.rs @@ -48,6 +48,19 @@ impl ResponseBuffer { } } +/// Request buffer with binary payload for `execute_command_with_binary`. +/// +/// Used for audio streaming — carries both JSON params and raw PCM bytes. +#[repr(C)] +struct StreamingRequestBuffer { + command: *const i8, + command_length: i32, + data: *const i8, + data_length: i32, + binary_data: *const u8, + binary_data_length: i32, +} + /// Signature for `execute_command`. type ExecuteCommandFn = unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer); @@ -63,6 +76,10 @@ type ExecuteCommandWithCallbackFn = unsafe extern "C" fn( *mut std::ffi::c_void, ); +/// Signature for `execute_command_with_binary`. +type ExecuteCommandWithBinaryFn = + unsafe extern "C" fn(*const StreamingRequestBuffer, *mut ResponseBuffer); + // ── Library name helpers ───────────────────────────────────────────────────── #[cfg(target_os = "windows")] @@ -237,6 +254,8 @@ pub(crate) struct CoreInterop { CallbackFn, *mut std::ffi::c_void, ), + execute_command_with_binary: + Option, } impl std::fmt::Debug for CoreInterop { @@ -307,12 +326,22 @@ impl CoreInterop { *sym }; + // SAFETY: Same as above — symbol must match `ExecuteCommandWithBinaryFn`. + // Optional: older native cores may not export this symbol (used for audio streaming). + let execute_command_with_binary: Option = unsafe { + library + .get::(b"execute_command_with_binary\0") + .ok() + .map(|sym| *sym) + }; + Ok(Self { _library: library, #[cfg(target_os = "windows")] _dependency_libs, execute_command, execute_command_with_callback, + execute_command_with_binary, }) } @@ -354,6 +383,61 @@ impl CoreInterop { Self::process_response(response) } + /// Execute a command with an additional binary payload. + /// + /// Used for audio streaming — `binary_data` carries raw PCM bytes + /// alongside the JSON parameters. + pub fn execute_command_with_binary( + &self, + command: &str, + params: Option<&Value>, + binary_data: &[u8], + ) -> Result { + let native_fn = self.execute_command_with_binary.ok_or_else(|| { + FoundryLocalError::CommandExecution { + reason: "execute_command_with_binary is not supported by this native core \ + (symbol not found)" + .into(), + } + })?; + + let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid command string: {e}"), + })?; + + let data_json = match params { + Some(v) => serde_json::to_string(v)?, + None => String::new(), + }; + let data_cstr = + CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid data string: {e}"), + })?; + + let request = StreamingRequestBuffer { + command: cmd.as_ptr(), + command_length: cmd.as_bytes().len() as i32, + data: data_cstr.as_ptr(), + data_length: data_cstr.as_bytes().len() as i32, + binary_data: if binary_data.is_empty() { + std::ptr::null() + } else { + binary_data.as_ptr() + }, + binary_data_length: binary_data.len() as i32, + }; + + let mut response = ResponseBuffer::new(); + + // SAFETY: `request` fields point into `cmd`, `data_cstr`, and + // `binary_data` which are all alive for the duration of this call. + unsafe { + (native_fn)(&request, &mut response); + } + + Self::process_response(response) + } + /// Execute a command that streams results back via `callback`. /// /// Each chunk delivered by the native library is decoded as UTF-8 and diff --git a/sdk/rust/src/lib.rs b/sdk/rust/src/lib.rs index 872a875c..9fb4bb85 100644 --- a/sdk/rust/src/lib.rs +++ b/sdk/rust/src/lib.rs @@ -31,8 +31,10 @@ pub use async_openai::types::chat::{ // Re-export OpenAI response types for convenience. pub use crate::openai::{ - AudioTranscriptionResponse, AudioTranscriptionStream, ChatCompletionStream, - TranscriptionSegment, TranscriptionWord, + AudioTranscriptionResponse, AudioTranscriptionStream, ChatCompletionStream, ContentPart, + CoreErrorResponse, LiveAudioTranscriptionOptions, LiveAudioTranscriptionResponse, + LiveAudioTranscriptionSession, LiveAudioTranscriptionStream, TranscriptionSegment, + TranscriptionWord, }; pub use async_openai::types::chat::{ ChatChoice, ChatChoiceStream, ChatCompletionMessageToolCall, diff --git a/sdk/rust/src/openai/audio_client.rs b/sdk/rust/src/openai/audio_client.rs index 0319da38..cc1813d0 100644 --- a/sdk/rust/src/openai/audio_client.rs +++ b/sdk/rust/src/openai/audio_client.rs @@ -9,6 +9,7 @@ use crate::detail::core_interop::CoreInterop; use crate::error::{FoundryLocalError, Result}; use super::json_stream::JsonStream; +use super::live_audio_client::LiveAudioTranscriptionSession; /// A segment of a transcription, as returned by the OpenAI-compatible API. #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] @@ -196,6 +197,15 @@ impl AudioClient { Ok(AudioTranscriptionStream::new(rx)) } + /// Create a [`LiveAudioTranscriptionSession`] for real-time audio + /// streaming transcription. + /// + /// Configure the session's [`settings`](LiveAudioTranscriptionSession::settings) + /// before calling [`start`](LiveAudioTranscriptionSession::start). + pub fn create_live_transcription_session(&self) -> LiveAudioTranscriptionSession { + LiveAudioTranscriptionSession::new(&self.model_id, Arc::clone(&self.core)) + } + fn validate_path(path: &str) -> Result<()> { if path.trim().is_empty() { return Err(FoundryLocalError::Validation { diff --git a/sdk/rust/src/openai/live_audio_client.rs b/sdk/rust/src/openai/live_audio_client.rs new file mode 100644 index 00000000..8b285a96 --- /dev/null +++ b/sdk/rust/src/openai/live_audio_client.rs @@ -0,0 +1,698 @@ +//! Live audio transcription streaming session. +//! +//! Provides real-time audio streaming ASR (Automatic Speech Recognition). +//! Audio data from a microphone (or other source) is pushed in as PCM chunks +//! and transcription results are returned as an async [`Stream`](futures_core::Stream). +//! +//! # Example +//! +//! ```ignore +//! let audio_client = model.create_audio_client(); +//! let mut session = audio_client.create_live_transcription_session(); +//! session.settings.sample_rate = 16000; +//! session.settings.channels = 1; +//! session.settings.language = Some("en".into()); +//! +//! session.start(None).await?; +//! +//! // Push audio from microphone callback +//! session.append(&pcm_bytes, None).await?; +//! +//! // Read results as async stream +//! use tokio_stream::StreamExt; +//! let mut stream = session.get_transcription_stream().await?; +//! while let Some(result) = stream.next().await { +//! let result = result?; +//! print!("{}", result.content[0].text); +//! } +//! +//! session.stop(None).await?; +//! ``` + +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use serde_json::json; +use tokio_util::sync::CancellationToken; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; + +// ── Types ──────────────────────────────────────────────────────────────────── + +/// Audio format settings for a live transcription session. +/// +/// Must be configured before calling [`LiveAudioTranscriptionSession::start`]. +/// Settings are frozen once the session starts. +#[derive(Debug, Clone)] +pub struct LiveAudioTranscriptionOptions { + /// PCM sample rate in Hz. Default: 16000. + pub sample_rate: u32, + /// Number of audio channels. Default: 1 (mono). + pub channels: u32, + /// Number of bits per audio sample. Default: 16. + pub bits_per_sample: u32, + /// Optional BCP-47 language hint (e.g., `"en"`, `"zh"`). + pub language: Option, + /// Maximum number of audio chunks buffered in the internal push queue. + /// If the queue is full, [`LiveAudioTranscriptionSession::append`] will + /// wait asynchronously. + /// Default: 100 (~3 seconds of audio at typical chunk sizes). + pub push_queue_capacity: usize, +} + +impl Default for LiveAudioTranscriptionOptions { + fn default() -> Self { + Self { + sample_rate: 16000, + channels: 1, + bits_per_sample: 16, + language: None, + push_queue_capacity: 100, + } + } +} + +/// Internal raw deserialization target matching the native core's JSON format. +#[derive(Debug, Clone, serde::Deserialize)] +struct LiveAudioTranscriptionRaw { + #[serde(default)] + is_final: bool, + #[serde(default)] + text: String, + start_time: Option, + end_time: Option, +} + +/// A content part within a [`LiveAudioTranscriptionResponse`]. +/// +/// Mirrors the C# `ContentPart` shape from the OpenAI Realtime API so that +/// callers can access `result.content[0].text` or `result.content[0].transcript` +/// consistently across SDKs. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct ContentPart { + /// The transcribed text. + pub text: String, + /// Same as `text` — provided for OpenAI Realtime API compatibility. + pub transcript: String, +} + +/// Transcription result from a live audio streaming session. +/// +/// Shaped to match the C# `LiveAudioTranscriptionResponse : ConversationItem` +/// so that callers access text via `result.content[0].text` or +/// `result.content[0].transcript`. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct LiveAudioTranscriptionResponse { + /// Content parts — typically a single element. Access text via + /// `result.content[0].text` or `result.content[0].transcript`. + pub content: Vec, + /// Whether this is a final or partial (interim) result. + /// Nemotron models always return `true`; other models may return `false` + /// for interim hypotheses that will be replaced by a subsequent final result. + pub is_final: bool, + /// Start time offset of this segment in the audio stream (seconds). + pub start_time: Option, + /// End time offset of this segment in the audio stream (seconds). + pub end_time: Option, +} + +impl LiveAudioTranscriptionResponse { + /// Parse a transcription response from the native core's JSON format. + pub fn from_json(json: &str) -> Result { + serde_json::from_str::(json) + .map(Self::from_raw) + .map_err(FoundryLocalError::from) + } + + fn from_raw(raw: LiveAudioTranscriptionRaw) -> Self { + Self { + content: vec![ContentPart { + transcript: raw.text.clone(), + text: raw.text, + }], + is_final: raw.is_final, + start_time: raw.start_time, + end_time: raw.end_time, + } + } +} + +/// Structured error response from the native core. +#[derive(Debug, Clone, serde::Deserialize)] +pub struct CoreErrorResponse { + /// Error code (e.g. `"ASR_SESSION_NOT_FOUND"`). + pub code: String, + /// Human-readable error message. + pub message: String, + /// Whether this error is transient (retryable). + #[serde(rename = "isTransient", default)] + pub is_transient: bool, +} + +impl CoreErrorResponse { + /// Attempt to parse a native error string as structured JSON. + /// Returns `None` if the error is not valid JSON or doesn't match the schema. + pub fn try_parse(error_string: &str) -> Option { + serde_json::from_str(error_string).ok() + } +} + +// ── Stream type ────────────────────────────────────────────────────────────── + +/// An async stream of [`LiveAudioTranscriptionResponse`] items. +/// +/// Returned by [`LiveAudioTranscriptionSession::get_transcription_stream`]. +/// Implements [`futures_core::Stream`]. +pub struct LiveAudioTranscriptionStream { + rx: tokio::sync::mpsc::UnboundedReceiver>, +} + +impl futures_core::Stream for LiveAudioTranscriptionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } +} + +// ── Session state ──────────────────────────────────────────────────────────── + +struct SessionState { + session_handle: Option, + started: bool, + stopped: bool, + push_tx: Option>>, + output_tx: Option>>, + output_rx: Option>>, + push_loop_handle: Option>, +} + +impl SessionState { + fn new() -> Self { + Self { + session_handle: None, + started: false, + stopped: false, + push_tx: None, + output_tx: None, + output_rx: None, + push_loop_handle: None, + } + } +} + +// ── Session ────────────────────────────────────────────────────────────────── + +/// Session for real-time audio streaming ASR (Automatic Speech Recognition). +/// +/// Audio data from a microphone (or other source) is pushed in as PCM chunks +/// via [`append`](Self::append), and transcription results are returned as an +/// async [`Stream`](futures_core::Stream) via +/// [`get_transcription_stream`](Self::get_transcription_stream). +/// +/// Created via [`AudioClient::create_live_transcription_session`](super::AudioClient::create_live_transcription_session). +/// +/// # Thread safety +/// +/// [`append`](Self::append) can be called from any thread (including +/// high-frequency audio callbacks). Pushes are internally serialized via a +/// bounded channel to prevent unbounded memory growth and ensure ordering. +/// +/// # Cancellation +/// +/// All lifecycle methods accept an optional [`CancellationToken`]. Pass `None` +/// to use the default (no cancellation). +pub struct LiveAudioTranscriptionSession { + model_id: String, + core: Arc, + /// Audio format settings. Must be configured before calling [`start`](Self::start). + /// Settings are frozen once the session starts. + pub settings: LiveAudioTranscriptionOptions, + state: tokio::sync::Mutex, +} + +impl LiveAudioTranscriptionSession { + pub(crate) fn new(model_id: &str, core: Arc) -> Self { + Self { + model_id: model_id.to_owned(), + core, + settings: LiveAudioTranscriptionOptions::default(), + state: tokio::sync::Mutex::new(SessionState::new()), + } + } + + /// Start a real-time audio streaming session. + /// + /// Must be called before [`append`](Self::append) or + /// [`get_transcription_stream`](Self::get_transcription_stream). + /// Settings are frozen after this call. + /// + /// # Cancellation + /// + /// Pass a [`CancellationToken`] to abort the start operation. If + /// cancelled, any native session that was created is cleaned up + /// automatically. + pub async fn start(&self, ct: Option) -> Result<()> { + let mut state = self.state.lock().await; + + if state.started { + return Err(FoundryLocalError::Validation { + reason: "Streaming session already started. Call stop() first.".into(), + }); + } + + let active_settings = self.settings.clone(); + + let (output_tx, output_rx) = + tokio::sync::mpsc::unbounded_channel::>(); + let (push_tx, push_rx) = + tokio::sync::mpsc::channel::>(active_settings.push_queue_capacity); + + let request = self.build_start_request(&active_settings); + + let core = Arc::clone(&self.core); + let start_future = tokio::task::spawn_blocking(move || { + core.execute_command("audio_stream_start", Some(&request)) + }); + + let session_handle = self.await_start(start_future, ct).await?; + + if session_handle.is_empty() { + return Err(FoundryLocalError::CommandExecution { + reason: "Native core did not return a session handle.".into(), + }); + } + + let push_loop_core = Arc::clone(&self.core); + let push_loop_output_tx = output_tx.clone(); + let handle_clone = session_handle.clone(); + let push_loop_handle = tokio::task::spawn_blocking(move || { + Self::push_loop(push_loop_core, handle_clone, push_rx, push_loop_output_tx); + }); + + state.session_handle = Some(session_handle); + state.started = true; + state.stopped = false; + state.push_tx = Some(push_tx); + state.output_tx = Some(output_tx); + state.output_rx = Some(output_rx); + state.push_loop_handle = Some(push_loop_handle); + + Ok(()) + } + + /// Push a chunk of raw PCM audio data to the streaming session. + /// + /// Can be called from any async context (including high-frequency audio + /// callbacks when wrapped). Chunks are internally queued and serialized to + /// the native core. + /// + /// The data is copied internally so the caller can reuse the buffer. + /// + /// # Cancellation + /// + /// Pass a [`CancellationToken`] to abort if the push queue is full + /// (backpressure). The audio chunk will not be queued if cancelled. + pub async fn append(&self, pcm_data: &[u8], ct: Option) -> Result<()> { + // Clone the sender while holding the lock, then drop the lock before + // awaiting the send. This prevents deadlock when the bounded push + // queue is full — stop() can still acquire the lock to close the + // channel and unblock the send. + let tx = { + let state = self.state.lock().await; + + if !state.started || state.stopped { + return Err(FoundryLocalError::Validation { + reason: "No active streaming session. Call start() first.".into(), + }); + } + + state + .push_tx + .clone() + .ok_or_else(|| FoundryLocalError::Internal { + reason: "Push channel not available — session may be in an invalid state" + .into(), + })? + }; + + let data = pcm_data.to_vec(); + + if let Some(token) = &ct { + tokio::select! { + result = tx.send(data) => { + result.map_err(|_| FoundryLocalError::CommandExecution { + reason: "Push channel closed — session has been stopped".into(), + }) + } + _ = token.cancelled() => { + Err(FoundryLocalError::CommandExecution { + reason: "Append cancelled".into(), + }) + } + } + } else { + tx.send(data) + .await + .map_err(|_| FoundryLocalError::CommandExecution { + reason: "Push channel closed — session has been stopped".into(), + }) + } + } + + /// Get the async stream of transcription results. + /// + /// Results arrive as the native ASR engine processes audio data. + /// Can only be called once per session (the receiver is moved out). + pub async fn get_transcription_stream(&self) -> Result { + let mut state = self.state.lock().await; + + let rx = state + .output_rx + .take() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "No active streaming session, or stream already taken. \ + Call start() first and only call get_transcription_stream() once." + .into(), + })?; + + Ok(LiveAudioTranscriptionStream { rx }) + } + + /// Signal end-of-audio and stop the streaming session. + /// + /// Any remaining buffered audio in the push queue will be drained to the + /// native core first. Final results are delivered through the transcription + /// stream before it completes. + /// + /// # Cancellation safety + /// + /// Even if the provided [`CancellationToken`] fires, the native session + /// stop is always completed to avoid native session leaks (matching the C# + /// `StopAsync` cancellation-safe pattern). + pub async fn stop(&self, ct: Option) -> Result<()> { + let mut state = self.state.lock().await; + + if !state.started || state.stopped { + return Ok(()); + } + + state.stopped = true; + + self.drain_push_loop(&mut state).await; + let stop_result = self.stop_native_session(&state, ct).await; + Self::write_final_result(&stop_result, &state); + self.finalize_state(&mut state); + + stop_result?; + Ok(()) + } + + // ── Private helpers ────────────────────────────────────────────────── + + /// Build the JSON request for `audio_stream_start`. + fn build_start_request(&self, settings: &LiveAudioTranscriptionOptions) -> serde_json::Value { + let mut params = json!({ + "Model": self.model_id, + "SampleRate": settings.sample_rate.to_string(), + "Channels": settings.channels.to_string(), + "BitsPerSample": settings.bits_per_sample.to_string(), + }); + if let Some(ref lang) = settings.language { + params["Language"] = json!(lang); + } + json!({ "Params": params }) + } + + /// Await the start future with cancellation safety. If cancelled, any + /// native session that was already created is cleaned up via + /// `audio_stream_stop`. + async fn await_start( + &self, + start_future: tokio::task::JoinHandle>, + ct: Option, + ) -> Result { + // Always await the start future — we cannot drop it because the + // spawn_blocking task may create a native session that would leak. + let join_result = start_future + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Start audio stream task join error: {e}"), + })?; + + // If a cancellation token was provided and is already cancelled, + // clean up any native session that was created and return an error. + if let Some(token) = ct { + if token.is_cancelled() { + if let Ok(ref handle) = join_result { + if !handle.is_empty() { + let params = json!({ + "Params": { "SessionHandle": handle } + }); + let _ = self + .core + .execute_command("audio_stream_stop", Some(¶ms)); + } + } + return Err(FoundryLocalError::CommandExecution { + reason: "Start cancelled".into(), + }); + } + } + + join_result + } + + /// Close the push channel and wait for the push loop to drain. + async fn drain_push_loop(&self, state: &mut SessionState) { + state.push_tx.take(); + if let Some(handle) = state.push_loop_handle.take() { + let _ = handle.await; + } + } + + /// Tell the native core to stop the audio stream session. Always completes + /// even if the cancellation token fires. + async fn stop_native_session( + &self, + state: &SessionState, + _ct: Option, + ) -> Result { + let session_handle = state + .session_handle + .as_ref() + .ok_or_else(|| FoundryLocalError::Internal { + reason: "Session handle missing during stop".into(), + })? + .clone(); + + let params = json!({ "Params": { "SessionHandle": session_handle } }); + let core = Arc::clone(&self.core); + + // Always await the native stop to completion regardless of cancellation. + // This prevents double-stop and native session leaks. + tokio::task::spawn_blocking(move || { + core.execute_command("audio_stream_stop", Some(¶ms)) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Stop audio stream task join error: {e}"), + })? + } + + /// Write a final transcription result from a stop response into the output channel. + fn write_final_result(stop_result: &Result, state: &SessionState) { + let _ = stop_result + .as_ref() + .ok() + .filter(|d| !d.is_empty()) + .and_then(|d| serde_json::from_str::(d).ok()) + .filter(|r| !r.text.is_empty()) + .and_then(|raw| { + state.output_tx.as_ref().map(|tx| { + let _ = tx.send(Ok(LiveAudioTranscriptionResponse::from_raw(raw))); + }) + }); + } + + /// Clean up session state after stop. + fn finalize_state(&self, state: &mut SessionState) { + state.output_tx.take(); + state.session_handle = None; + state.started = false; + } + + /// Internal push loop — runs entirely on a blocking thread. + /// + /// Drains the push queue and sends chunks to the native core one at a time. + /// Terminates the session on any native error. + fn push_loop( + core: Arc, + session_handle: String, + mut push_rx: tokio::sync::mpsc::Receiver>, + output_tx: tokio::sync::mpsc::UnboundedSender>, + ) { + while let Some(audio_data) = push_rx.blocking_recv() { + let params = json!({ + "Params": { "SessionHandle": &session_handle } + }); + + let data = match core.execute_command_with_binary( + "audio_stream_push", + Some(¶ms), + &audio_data, + ) { + Ok(d) => d, + Err(e) => { + let code = match &e { + FoundryLocalError::CommandExecution { reason } => { + CoreErrorResponse::try_parse(reason) + .map(|ei| ei.code) + .unwrap_or_else(|| "UNKNOWN".into()) + } + _ => "UNKNOWN".into(), + }; + let _ = output_tx.send(Err(FoundryLocalError::CommandExecution { + reason: format!("Push failed (code={code}): {e}"), + })); + // Fatal push failures are terminal for the transcription stream. + // Drop the sender and return so the stream completes. + drop(output_tx); + return; + } + }; + + if let Ok(raw) = serde_json::from_str::(&data) { + if !raw.text.is_empty() { + let _ = output_tx.send(Ok(LiveAudioTranscriptionResponse::from_raw(raw))); + } + } + } + } +} + +// ── Drop impl ──────────────────────────────────────────────────────────────── + +impl Drop for LiveAudioTranscriptionSession { + fn drop(&mut self) { + if let Ok(mut state) = self.state.try_lock() { + state.push_tx.take(); + state.output_tx.take(); + + if state.started && !state.stopped { + if let Some(ref handle) = state.session_handle { + let params = serde_json::json!({ + "Params": { "SessionHandle": handle } + }); + let _ = self + .core + .execute_command("audio_stream_stop", Some(¶ms)); + } + state.session_handle = None; + state.started = false; + state.stopped = true; + } + } + } +} + +// ── Tests ──────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn from_json_parses_text_and_is_final() { + let json = r#"{"is_final":true,"text":"hello world","start_time":null,"end_time":null}"#; + let result = LiveAudioTranscriptionResponse::from_json(json).unwrap(); + + assert_eq!(result.content.len(), 1); + assert_eq!(result.content[0].text, "hello world"); + assert_eq!(result.content[0].transcript, "hello world"); + assert!(result.is_final); + } + + #[test] + fn from_json_maps_timing_fields() { + let json = r#"{"is_final":false,"text":"partial","start_time":1.5,"end_time":3.0}"#; + let result = LiveAudioTranscriptionResponse::from_json(json).unwrap(); + + assert_eq!(result.content[0].text, "partial"); + assert!(!result.is_final); + assert_eq!(result.start_time, Some(1.5)); + assert_eq!(result.end_time, Some(3.0)); + } + + #[test] + fn from_json_empty_text_parses_successfully() { + let json = r#"{"is_final":true,"text":"","start_time":null,"end_time":null}"#; + let result = LiveAudioTranscriptionResponse::from_json(json).unwrap(); + + assert_eq!(result.content[0].text, ""); + assert!(result.is_final); + } + + #[test] + fn from_json_only_start_time_sets_start_time() { + let json = r#"{"is_final":true,"text":"word","start_time":2.0,"end_time":null}"#; + let result = LiveAudioTranscriptionResponse::from_json(json).unwrap(); + + assert_eq!(result.start_time, Some(2.0)); + assert_eq!(result.end_time, None); + assert_eq!(result.content[0].text, "word"); + } + + #[test] + fn from_json_invalid_json_returns_error() { + let result = LiveAudioTranscriptionResponse::from_json("not valid json"); + assert!(result.is_err()); + } + + #[test] + fn from_json_content_has_text_and_transcript() { + let json = r#"{"is_final":true,"text":"test","start_time":null,"end_time":null}"#; + let result = LiveAudioTranscriptionResponse::from_json(json).unwrap(); + + assert_eq!(result.content[0].text, "test"); + assert_eq!(result.content[0].transcript, "test"); + } + + #[test] + fn options_default_values() { + let options = LiveAudioTranscriptionOptions::default(); + + assert_eq!(options.sample_rate, 16000); + assert_eq!(options.channels, 1); + assert_eq!(options.bits_per_sample, 16); + assert_eq!(options.language, None); + assert_eq!(options.push_queue_capacity, 100); + } + + #[test] + fn core_error_response_try_parse_valid_json() { + let json = + r#"{"code":"ASR_SESSION_NOT_FOUND","message":"Session not found","isTransient":false}"#; + let error = CoreErrorResponse::try_parse(json).unwrap(); + + assert_eq!(error.code, "ASR_SESSION_NOT_FOUND"); + assert_eq!(error.message, "Session not found"); + assert!(!error.is_transient); + } + + #[test] + fn core_error_response_try_parse_invalid_json_returns_none() { + let result = CoreErrorResponse::try_parse("not json"); + assert!(result.is_none()); + } + + #[test] + fn core_error_response_try_parse_transient_error() { + let json = r#"{"code":"BUSY","message":"Model busy","isTransient":true}"#; + let error = CoreErrorResponse::try_parse(json).unwrap(); + + assert!(error.is_transient); + } +} diff --git a/sdk/rust/src/openai/mod.rs b/sdk/rust/src/openai/mod.rs index 5c17a0df..ae0f1996 100644 --- a/sdk/rust/src/openai/mod.rs +++ b/sdk/rust/src/openai/mod.rs @@ -2,6 +2,7 @@ mod audio_client; mod chat_client; mod embedding_client; mod json_stream; +mod live_audio_client; pub use self::audio_client::{ AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream, @@ -10,3 +11,7 @@ pub use self::audio_client::{ pub use self::chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream}; pub use self::embedding_client::EmbeddingClient; pub use self::json_stream::JsonStream; +pub use self::live_audio_client::{ + ContentPart, CoreErrorResponse, LiveAudioTranscriptionOptions, LiveAudioTranscriptionResponse, + LiveAudioTranscriptionSession, LiveAudioTranscriptionStream, +}; diff --git a/sdk/rust/tests/integration/live_audio_test.rs b/sdk/rust/tests/integration/live_audio_test.rs new file mode 100644 index 00000000..4961d83b --- /dev/null +++ b/sdk/rust/tests/integration/live_audio_test.rs @@ -0,0 +1,117 @@ +use super::common; +use std::sync::Arc; +use tokio_stream::StreamExt; + +/// Generate synthetic PCM audio (440Hz sine wave, 16kHz, 16-bit mono). +fn generate_sine_wave_pcm(sample_rate: i32, duration_seconds: i32, frequency: f64) -> Vec { + let total_samples = (sample_rate * duration_seconds) as usize; + let mut pcm_bytes = vec![0u8; total_samples * 2]; // 16-bit = 2 bytes per sample + + for i in 0..total_samples { + let t = i as f64 / sample_rate as f64; + let sample = + (i16::MAX as f64 * 0.5 * (2.0 * std::f64::consts::PI * frequency * t).sin()) as i16; + pcm_bytes[i * 2] = (sample & 0xFF) as u8; + pcm_bytes[i * 2 + 1] = ((sample >> 8) & 0xFF) as u8; + } + + pcm_bytes +} + +// --- E2E streaming test with synthetic PCM audio --- + +#[tokio::test] +async fn live_streaming_e2e_with_synthetic_pcm_returns_valid_response() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + // Try to get a nemotron or whisper model for audio streaming + let model = match catalog.get_model("nemotron").await { + Ok(m) => m, + Err(_) => match catalog.get_model(common::WHISPER_MODEL_ALIAS).await { + Ok(m) => m, + Err(_) => { + eprintln!("Skipping E2E test: no audio model available"); + return; + } + }, + }; + + if !model.is_cached().await.unwrap_or(false) { + eprintln!("Skipping E2E test: model not cached"); + return; + } + + model.load().await.expect("model.load() failed"); + + let audio_client = model.create_audio_client(); + let session = audio_client.create_live_transcription_session(); + + // Verify default settings + assert_eq!(session.settings.sample_rate, 16000); + assert_eq!(session.settings.channels, 1); + assert_eq!(session.settings.bits_per_sample, 16); + + if let Err(e) = session.start(None).await { + eprintln!("Skipping E2E test: could not start session: {e}"); + model.unload().await.ok(); + return; + } + + // Start collecting results in background (must start before pushing audio) + let mut stream = session + .get_transcription_stream() + .await + .expect("get_transcription_stream failed"); + + let results = Arc::new(tokio::sync::Mutex::new(Vec::new())); + let stream_error: Arc>> = + Arc::new(tokio::sync::Mutex::new(None)); + let results_clone = Arc::clone(&results); + let error_clone = Arc::clone(&stream_error); + let read_task = tokio::spawn(async move { + while let Some(result) = stream.next().await { + match result { + Ok(r) => results_clone.lock().await.push(r), + Err(e) => { + *error_clone.lock().await = Some(format!("{e}")); + break; + } + } + } + }); + + // Generate ~2 seconds of synthetic PCM audio (440Hz sine wave) + let pcm_bytes = generate_sine_wave_pcm(16000, 2, 440.0); + + // Push audio in chunks (100ms each, matching typical mic callback size) + let chunk_size = 16000 / 10 * 2; // 100ms of 16-bit audio = 3200 bytes + for offset in (0..pcm_bytes.len()).step_by(chunk_size) { + let end = std::cmp::min(offset + chunk_size, pcm_bytes.len()); + session + .append(&pcm_bytes[offset..end], None) + .await + .expect("append failed"); + } + + // Stop session to flush remaining audio and complete the stream + session.stop(None).await.expect("stop failed"); + read_task.await.expect("read task failed"); + + // Verify no stream errors occurred + assert!( + stream_error.lock().await.is_none(), + "Stream produced an error: {:?}", + stream_error.lock().await + ); + + // Verify response attributes — synthetic audio may or may not produce text, + // but the response objects should be properly structured (C#-compatible envelope) + let results = results.lock().await; + for result in results.iter() { + assert!(!result.content.is_empty(), "content must not be empty"); + assert_eq!(result.content[0].text, result.content[0].transcript); + } + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk/rust/tests/integration/main.rs b/sdk/rust/tests/integration/main.rs index c63956f3..05576000 100644 --- a/sdk/rust/tests/integration/main.rs +++ b/sdk/rust/tests/integration/main.rs @@ -12,6 +12,7 @@ mod audio_client_test; mod catalog_test; mod chat_client_test; mod embedding_client_test; +mod live_audio_test; mod manager_test; mod model_test; mod web_service_test;