Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion plugins/local-stt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,13 @@ tower-http = { workspace = true, features = ["cors", "trace"] }

backon = { workspace = true }
futures-util = { workspace = true }
ractor = { workspace = true }
reqwest = { workspace = true }
tokio = { workspace = true, features = ["rt", "macros"] }
tokio-util = { workspace = true }
tracing = { workspace = true }

ractor = { workspace = true }
ractor-supervisor = { workspace = true }

port-killer = "0.1.0"
port_check = "0.3.0"
10 changes: 4 additions & 6 deletions plugins/local-stt/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@ pub enum Error {
HyprFileError(#[from] hypr_file::Error),
#[error(transparent)]
ShellError(#[from] tauri_plugin_shell::Error),
#[error(transparent)]
TauriError(#[from] tauri::Error),
#[error(transparent)]
IoError(#[from] std::io::Error),
#[error("Model not downloaded")]
ModelNotDownloaded,
#[error("Server already running")]
ServerAlreadyRunning,
#[error("Server start failed {0}")]
ServerStartFailed(String),
#[error("Server stop failed {0}")]
ServerStopFailed(String),
#[error("Supervisor not found")]
SupervisorNotFound,
#[error("AM binary not found")]
AmBinaryNotFound,
#[error("AM API key not set")]
Expand Down
270 changes: 133 additions & 137 deletions plugins/local-stt/src/ext.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::{collections::HashMap, future::Future, path::PathBuf, time::Duration};
use std::{collections::HashMap, future::Future, path::PathBuf, sync::Arc};

use ractor::{call_t, registry, Actor, ActorRef};
use tokio::time::sleep;
use ractor::{call_t, registry, ActorRef};
use tokio_util::sync::CancellationToken;

use tauri::{ipc::Channel, Manager, Runtime};
Expand All @@ -12,13 +11,17 @@ use hypr_file::download_file_parallel_cancellable;

use crate::{
model::SupportedSttModel,
server::{external, internal, ServerInfo, ServerStatus, ServerType},
server::{external, internal, supervisor, ServerInfo, ServerStatus, ServerType},
};

pub trait LocalSttPluginExt<R: Runtime> {
fn models_dir(&self) -> PathBuf;
fn list_ggml_backends(&self) -> Vec<hypr_whisper_local::GgmlBackend>;

fn get_supervisor(
&self,
) -> impl Future<Output = Result<supervisor::SupervisorRef, crate::Error>>;

fn start_server(
&self,
model: SupportedSttModel,
Expand Down Expand Up @@ -53,6 +56,15 @@ impl<R: Runtime, T: Manager<R>> LocalSttPluginExt<R> for T {
hypr_whisper_local::list_ggml_backends()
}

async fn get_supervisor(&self) -> Result<supervisor::SupervisorRef, crate::Error> {
let state = self.state::<crate::SharedState>();
let guard = state.lock().await;
guard
.stt_supervisor
.clone()
.ok_or(crate::Error::SupervisorNotFound)
}

async fn is_model_downloaded(&self, model: &SupportedSttModel) -> Result<bool, crate::Error> {
match model {
SupportedSttModel::Am(model) => Ok(model.is_downloaded(self.models_dir())?),
Expand All @@ -77,12 +89,12 @@ impl<R: Runtime, T: Manager<R>> LocalSttPluginExt<R> for T {

#[tracing::instrument(skip_all)]
async fn start_server(&self, model: SupportedSttModel) -> Result<String, crate::Error> {
let t = match &model {
let server_type = match &model {
SupportedSttModel::Am(_) => ServerType::External,
SupportedSttModel::Whisper(_) => ServerType::Internal,
};

let current_info = match t {
let current_info = match server_type {
ServerType::Internal => internal_health().await,
ServerType::External => external_health().await,
};
Expand All @@ -99,164 +111,56 @@ impl<R: Runtime, T: Manager<R>> LocalSttPluginExt<R> for T {
}
}

if matches!(t, ServerType::Internal) && !self.is_model_downloaded(&model).await? {
if matches!(server_type, ServerType::Internal) && !self.is_model_downloaded(&model).await? {
return Err(crate::Error::ModelNotDownloaded);
}

let am_key = if matches!(t, ServerType::External) {
let state = self.state::<crate::SharedState>();
let key = {
let guard = state.lock().await;
guard.am_api_key.clone()
};
let key = key
.filter(|k| !k.is_empty())
.ok_or(crate::Error::AmApiKeyNotSet)?;
Some(key)
} else {
None
};

let cache_dir = self.models_dir();
let data_dir = self.app_handle().path().app_data_dir().unwrap().join("stt");
let supervisor = self.get_supervisor().await?;

self.stop_server(None).await?;
// Need some delay
sleep(Duration::from_millis(300)).await;
supervisor::stop_all_stt_servers(&supervisor)
.await
.map_err(|e| crate::Error::ServerStopFailed(e.to_string()))?;

match t {
match server_type {
ServerType::Internal => {
let cache_dir = self.models_dir();
let whisper_model = match model {
SupportedSttModel::Whisper(m) => m,
_ => {
return Err(crate::Error::UnsupportedModelType);
}
_ => return Err(crate::Error::UnsupportedModelType),
};

let (_server, _) = Actor::spawn(
Some(internal::InternalSTTActor::name()),
internal::InternalSTTActor,
internal::InternalSTTArgs {
model_cache_dir: cache_dir,
model_type: whisper_model,
},
)
.await
.map_err(|e| crate::Error::ServerStartFailed(e.to_string()))?;

internal_health()
.await
.and_then(|info| info.url)
.ok_or_else(|| crate::Error::ServerStartFailed("empty_health".to_string()))
start_internal_server(&supervisor, cache_dir, whisper_model).await
}
ServerType::External => {
let data_dir = self.app_handle().path().app_data_dir().unwrap().join("stt");
let am_model = match model {
SupportedSttModel::Am(m) => m,
_ => {
return Err(crate::Error::UnsupportedModelType);
}
};

let am_key = match am_key {
Some(key) => key,
None => {
return Err(crate::Error::AmApiKeyNotSet);
}
};

let cmd: tauri_plugin_shell::process::Command = {
#[cfg(debug_assertions)]
{
let passthrough_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../../apps/desktop/src-tauri/resources/passthrough-aarch64-apple-darwin");
let stt_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(
"../../apps/desktop/src-tauri/resources/stt-aarch64-apple-darwin",
);

if !passthrough_path.exists() || !stt_path.exists() {
return Err(crate::Error::AmBinaryNotFound);
}

self.shell()
.command(passthrough_path)
.current_dir(dirs::home_dir().unwrap())
.arg(stt_path)
.args(["serve", "--any-token", "-v", "-d"])
}

#[cfg(not(debug_assertions))]
self.shell()
.sidecar("stt")?
.current_dir(dirs::home_dir().unwrap())
.args(["serve", "--any-token"])
_ => return Err(crate::Error::UnsupportedModelType),
};

let (_server, _) = Actor::spawn(
Some(external::ExternalSTTActor::name()),
external::ExternalSTTActor,
external::ExternalSTTArgs {
cmd,
api_key: am_key,
model: am_model,
models_dir: data_dir,
},
)
.await
.map_err(|e| crate::Error::ServerStartFailed(e.to_string()))?;

external_health()
.await
.and_then(|info| info.url)
.ok_or_else(|| crate::Error::ServerStartFailed("empty_health".to_string()))
start_external_server(self, &supervisor, data_dir, am_model).await
}
}
}

#[tracing::instrument(skip_all)]
async fn stop_server(&self, server_type: Option<ServerType>) -> Result<bool, crate::Error> {
let mut stopped = false;
let supervisor = self.get_supervisor().await?;

match server_type {
Some(ServerType::External) => {
if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) {
let actor: ActorRef<external::ExternalSTTMessage> = cell.into();
if let Err(e) = actor.stop_and_wait(None, None).await {
tracing::error!("stop_server: {:?}", e);
} else {
stopped = true;
}
}
}
Some(ServerType::Internal) => {
if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) {
let actor: ActorRef<internal::InternalSTTMessage> = cell.into();
if let Err(e) = actor.stop_and_wait(None, None).await {
tracing::error!("stop_server: {:?}", e);
} else {
stopped = true;
}
}
Some(t) => {
supervisor::stop_stt_server(&supervisor, t)
.await
.map_err(|e| crate::Error::ServerStopFailed(e.to_string()))?;
Ok(true)
}
None => {
if let Some(cell) = registry::where_is(external::ExternalSTTActor::name()) {
let actor: ActorRef<external::ExternalSTTMessage> = cell.into();
if let Err(e) = actor.stop_and_wait(None, None).await {
tracing::error!("stop_server: {:?}", e);
} else {
stopped = true;
}
}
if let Some(cell) = registry::where_is(internal::InternalSTTActor::name()) {
let actor: ActorRef<internal::InternalSTTMessage> = cell.into();
if let Err(e) = actor.stop_and_wait(None, None).await {
tracing::error!("stop_server: {:?}", e);
} else {
stopped = true;
}
}
supervisor::stop_all_stt_servers(&supervisor)
.await
.map_err(|e| crate::Error::ServerStopFailed(e.to_string()))?;
Ok(true)
}
}

Ok(stopped)
}

#[tracing::instrument(skip_all)]
Expand Down Expand Up @@ -410,6 +314,98 @@ impl<R: Runtime, T: Manager<R>> LocalSttPluginExt<R> for T {
}
}

async fn start_internal_server(
supervisor: &supervisor::SupervisorRef,
cache_dir: PathBuf,
model: hypr_whisper_local_model::WhisperModel,
) -> Result<String, crate::Error> {
supervisor::start_internal_stt(
supervisor,
internal::InternalSTTArgs {
model_cache_dir: cache_dir,
model_type: model,
},
)
.await
.map_err(|e| crate::Error::ServerStartFailed(e.to_string()))?;

internal_health()
.await
.and_then(|info| info.url)
.ok_or_else(|| crate::Error::ServerStartFailed("empty_health".to_string()))
}

async fn start_external_server<R: Runtime, T: Manager<R>>(
manager: &T,
supervisor: &supervisor::SupervisorRef,
data_dir: PathBuf,
model: hypr_am::AmModel,
) -> Result<String, crate::Error> {
let am_key = {
let state = manager.state::<crate::SharedState>();
let key = {
let guard = state.lock().await;
guard.am_api_key.clone()
};

key.filter(|k| !k.is_empty())
.ok_or(crate::Error::AmApiKeyNotSet)?
};

let port = port_check::free_local_port()
.ok_or_else(|| crate::Error::ServerStartFailed("failed_to_find_free_port".to_string()))?;

let app_handle = manager.app_handle().clone();
let cmd_builder = {
#[cfg(debug_assertions)]
{
let passthrough_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../../apps/desktop/src-tauri/resources/passthrough-aarch64-apple-darwin");
let stt_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
.join("../../apps/desktop/src-tauri/resources/stt-aarch64-apple-darwin");

if !passthrough_path.exists() || !stt_path.exists() {
return Err(crate::Error::AmBinaryNotFound);
}

let passthrough_path = Arc::new(passthrough_path);
let stt_path = Arc::new(stt_path);
external::CommandBuilder::new(move || {
app_handle
.shell()
.command(passthrough_path.as_ref())
.current_dir(dirs::home_dir().unwrap())
.arg(stt_path.as_ref())
.args(["serve", "--any-token", "-v", "-d"])
})
}

#[cfg(not(debug_assertions))]
{
external::CommandBuilder::new(move || {
app_handle
.shell()
.sidecar("stt")
.expect("failed to create sidecar command")
.current_dir(dirs::home_dir().unwrap())
.args(["serve", "--any-token"])
})
}
};

supervisor::start_external_stt(
supervisor,
external::ExternalSTTArgs::new(cmd_builder, am_key, model, data_dir, port),
)
.await
.map_err(|e| crate::Error::ServerStartFailed(e.to_string()))?;

external_health()
.await
.and_then(|info| info.url)
.ok_or_else(|| crate::Error::ServerStartFailed("empty_health".to_string()))
}

async fn internal_health() -> Option<ServerInfo> {
match registry::where_is(internal::InternalSTTActor::name()) {
Some(cell) => {
Expand Down
Loading
Loading