diff --git a/.github/workflows/desktop_cd.yaml b/.github/workflows/desktop_cd.yaml index 1f7329ebb..097d0601e 100644 --- a/.github/workflows/desktop_cd.yaml +++ b/.github/workflows/desktop_cd.yaml @@ -107,14 +107,13 @@ jobs: AWS_ACCESS_KEY_ID: ${{ secrets.CLOUDFLARE_R2_ACCESS_KEY_ID }} AWS_SECRET_ACCESS_KEY: ${{ secrets.CLOUDFLARE_R2_SECRET_ACCESS_KEY }} - if: ${{ matrix.target == 'aarch64-apple-darwin' }} - run: | - chmod +x apps/desktop/src-tauri/binaries/stt-${{ matrix.target }} - ./scripts/sidecar.sh "apps/desktop/${{ env.TAURI_CONF_PATH }}" "binaries/stt" + run: chmod +x ./apps/desktop/src-tauri/binaries/stt-${{ matrix.target }} && ./scripts/sidecar.sh "./apps/desktop/${{ env.TAURI_CONF_PATH }}" "binaries/stt" - run: pnpm -F desktop tauri build --target ${{ matrix.target }} --config ${{ env.TAURI_CONF_PATH }} --verbose env: # https://github.com/tauri-apps/tauri-action/issues/740 CI: false GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + AM_API_KEY: ${{ secrets.AM_API_KEY }} POSTHOG_API_KEY: ${{ secrets.POSTHOG_API_KEY }} SENTRY_DSN: ${{ secrets.SENTRY_DSN }} KEYGEN_ACCOUNT_ID: ${{ secrets.KEYGEN_ACCOUNT_ID }} diff --git a/Cargo.lock b/Cargo.lock index 5048e2055..4687c7211 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -184,6 +184,17 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "am" +version = "0.1.0" +dependencies = [ + "reqwest 0.12.22", + "serde", + "serde_json", + "thiserror 2.0.12", + "tokio", +] + [[package]] name = "analytics" version = "0.1.0" @@ -4526,6 +4537,7 @@ dependencies = [ "testcontainers-modules", "thiserror 2.0.12", "tokio", + "tracing", "wiremock", ] @@ -8963,6 +8975,19 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "moonshine" +version = "0.1.0" +dependencies = [ + "data", + "dirs 6.0.0", + "onnx", + "owhisper-config", + "rodio", + "thiserror 2.0.12", + "tokenizers", +] + [[package]] name = "moshi" version = "0.6.3" @@ -10195,6 +10220,7 @@ dependencies = [ "serde_json", "specta", "strum 0.26.3", + "uuid", ] [[package]] @@ 
-14605,6 +14631,7 @@ dependencies = [ name = "tauri-plugin-local-stt" version = "0.1.0" dependencies = [ + "am", "audio-utils", "axum 0.8.4", "axum-extra", @@ -16193,19 +16220,14 @@ dependencies = [ "axum 0.8.4", "bytes", "chunker", - "data", - "dirs 6.0.0", "futures-util", - "kalosm-sound", - "onnx", + "moonshine", "owhisper-config", "owhisper-interface", - "rodio", "serde", "serde_json", "serde_qs 1.0.0-rc.3", "thiserror 2.0.12", - "tokenizers", "tokio", "tokio-stream", "tokio-tungstenite 0.26.2", diff --git a/Cargo.toml b/Cargo.toml index fc895bb87..fd84a4e3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ rust-version = "1.88.0" [workspace.dependencies] hypr-aec = { path = "crates/aec", package = "aec" } hypr-agc = { path = "crates/agc", package = "agc" } +hypr-am = { path = "crates/am", package = "am" } hypr-analytics = { path = "crates/analytics", package = "analytics" } hypr-audio = { path = "crates/audio", package = "audio" } hypr-audio-utils = { path = "crates/audio-utils", package = "audio-utils" } @@ -47,6 +48,7 @@ hypr-kyutai = { path = "crates/kyutai", package = "kyutai" } hypr-language = { path = "crates/language", package = "language" } hypr-llama = { path = "crates/llama", package = "llama" } hypr-loops = { path = "crates/loops", package = "loops" } +hypr-moonshine = { path = "crates/moonshine", package = "moonshine" } hypr-nango = { path = "crates/nango", package = "nango" } hypr-network = { path = "crates/network", package = "network" } hypr-notification = { path = "crates/notification", package = "notification" } diff --git a/apps/desktop/src-tauri/src/ext.rs b/apps/desktop/src-tauri/src/ext.rs index 27ee1ae0c..27682dfdb 100644 --- a/apps/desktop/src-tauri/src/ext.rs +++ b/apps/desktop/src-tauri/src/ext.rs @@ -42,7 +42,7 @@ impl> AppExt for T { .unwrap_or(hypr_whisper_local_model::WhisperModel::QuantizedBaseEn); if let Ok(true) = self.is_model_downloaded(¤t_model).await { - if let Err(e) = self.start_server().await { + if let Err(e) = 
self.start_server(None).await { tracing::error!("start_local_stt_server: {}", e); } } diff --git a/apps/desktop/src/components/settings/components/ai/stt-view.tsx b/apps/desktop/src/components/settings/components/ai/stt-view.tsx index 9b0924197..28139b775 100644 --- a/apps/desktop/src/components/settings/components/ai/stt-view.tsx +++ b/apps/desktop/src/components/settings/components/ai/stt-view.tsx @@ -6,12 +6,8 @@ import { useEffect } from "react"; import { useForm } from "react-hook-form"; import { z } from "zod"; -// Add these imports for file operations -// import { message } from "@tauri-apps/plugin-dialog"; -// import { writeFile } from "@tauri-apps/plugin-fs"; - import { commands as dbCommands } from "@hypr/plugin-db"; -import { commands as localSttCommands, SupportedModel } from "@hypr/plugin-local-stt"; +import { commands as localSttCommands, type WhisperModel } from "@hypr/plugin-local-stt"; import { Button } from "@hypr/ui/components/ui/button"; import { Form, @@ -28,7 +24,7 @@ import { cn } from "@hypr/ui/lib/utils"; import { WERPerformanceModal } from "../wer-modal"; import { SharedSTTProps } from "./shared"; -export const sttModelMetadata: Record { if (model.downloaded) { setSelectedSTTModel(model.key); - localSttCommands.setCurrentModel(model.key as SupportedModel); - localSttCommands.restartServer(); + localSttCommands.setCurrentModel(model.key as WhisperModel); + localSttCommands.stopServer(null); + localSttCommands.startServer(null); } }} > diff --git a/apps/desktop/src/components/toast/model-select.tsx b/apps/desktop/src/components/toast/model-select.tsx index 432ee9956..daee646ba 100644 --- a/apps/desktop/src/components/toast/model-select.tsx +++ b/apps/desktop/src/components/toast/model-select.tsx @@ -1,13 +1,13 @@ import type { LinkProps } from "@tanstack/react-router"; -import { commands as localSttCommands, SupportedModel } from "@hypr/plugin-local-stt"; +import { commands as localSttCommands, type WhisperModel } from 
"@hypr/plugin-local-stt"; import { commands as windowsCommands } from "@hypr/plugin-windows"; import { Button } from "@hypr/ui/components/ui/button"; import { sonnerToast, toast } from "@hypr/ui/components/ui/toast"; export async function showModelSelectToast(language: string) { const currentModel = await localSttCommands.getCurrentModel(); - const englishModels: SupportedModel[] = ["QuantizedTinyEn", "QuantizedBaseEn", "QuantizedSmallEn"]; + const englishModels: WhisperModel[] = ["QuantizedTinyEn", "QuantizedBaseEn", "QuantizedSmallEn"]; if (language === "en" || !englishModels.includes(currentModel)) { return; diff --git a/apps/desktop/src/components/toast/shared.tsx b/apps/desktop/src/components/toast/shared.tsx index 277ce99cb..17625c9b0 100644 --- a/apps/desktop/src/components/toast/shared.tsx +++ b/apps/desktop/src/components/toast/shared.tsx @@ -3,7 +3,7 @@ import { Channel } from "@tauri-apps/api/core"; import { useEffect, useState } from "react"; import { commands as localLlmCommands, SupportedModel as SupportedModelLLM } from "@hypr/plugin-local-llm"; -import { commands as localSttCommands, SupportedModel } from "@hypr/plugin-local-stt"; +import { commands as localSttCommands, type WhisperModel } from "@hypr/plugin-local-stt"; import { commands as windowsCommands } from "@hypr/plugin-windows"; import { Button } from "@hypr/ui/components/ui/button"; import { Progress } from "@hypr/ui/components/ui/progress"; @@ -52,7 +52,7 @@ export const DownloadProgress = ({ ); }; -export function showSttModelDownloadToast(model: SupportedModel, onComplete?: () => void, queryClient?: QueryClient) { +export function showSttModelDownloadToast(model: WhisperModel, onComplete?: () => void, queryClient?: QueryClient) { const sttChannel = new Channel(); localSttCommands.downloadModel(model, sttChannel); @@ -77,7 +77,7 @@ export function showSttModelDownloadToast(model: SupportedModel, onComplete?: () onComplete={() => { sonnerToast.dismiss(id); 
localSttCommands.setCurrentModel(model); - localSttCommands.startServer(); + localSttCommands.startServer(null); if (onComplete) { onComplete(); } diff --git a/apps/desktop/src/components/welcome-modal/download-progress-view.tsx b/apps/desktop/src/components/welcome-modal/download-progress-view.tsx index f56f2e02d..c84ae8f25 100644 --- a/apps/desktop/src/components/welcome-modal/download-progress-view.tsx +++ b/apps/desktop/src/components/welcome-modal/download-progress-view.tsx @@ -4,7 +4,7 @@ import { BrainIcon, CheckCircle2Icon, MicIcon } from "lucide-react"; import { useEffect, useState } from "react"; import { commands as localLlmCommands } from "@hypr/plugin-local-llm"; -import { commands as localSttCommands, SupportedModel } from "@hypr/plugin-local-stt"; +import { commands as localSttCommands, type WhisperModel } from "@hypr/plugin-local-stt"; import { Progress } from "@hypr/ui/components/ui/progress"; import PushableButton from "@hypr/ui/components/ui/pushable-button"; import { cn } from "@hypr/ui/lib/utils"; @@ -18,7 +18,7 @@ interface ModelDownloadProgress { } interface DownloadProgressViewProps { - selectedSttModel: SupportedModel; + selectedSttModel: WhisperModel; llmSelection: "hyprllm" | "byom" | null; onContinue: () => void; } @@ -174,7 +174,7 @@ export const DownloadProgressView = ({ if (sttDownload.completed) { try { await localSttCommands.setCurrentModel(selectedSttModel); - await localSttCommands.startServer(); + await localSttCommands.startServer(null); } catch (error) { console.error("Error setting up STT:", error); } diff --git a/apps/desktop/src/components/welcome-modal/index.tsx b/apps/desktop/src/components/welcome-modal/index.tsx index a59d0d960..82d0726ce 100644 --- a/apps/desktop/src/components/welcome-modal/index.tsx +++ b/apps/desktop/src/components/welcome-modal/index.tsx @@ -1,29 +1,28 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import { useNavigate } from "@tanstack/react-router"; import { message } from 
"@tauri-apps/plugin-dialog"; -import { ArrowLeft } from "lucide-react"; // Add this import +import { ArrowLeft } from "lucide-react"; import { useEffect, useState } from "react"; import { showLlmModelDownloadToast, showSttModelDownloadToast } from "@/components/toast/shared"; import { commands } from "@/types"; import { commands as authCommands, events } from "@hypr/plugin-auth"; -import { commands as localSttCommands, SupportedModel } from "@hypr/plugin-local-stt"; +import { commands as localSttCommands, type WhisperModel } from "@hypr/plugin-local-stt"; import { commands as sfxCommands } from "@hypr/plugin-sfx"; import { Modal, ModalBody } from "@hypr/ui/components/ui/modal"; import { Particles } from "@hypr/ui/components/ui/particles"; import { ConfigureEndpointConfig } from "../settings/components/ai/shared"; +import { useHypr } from "@/contexts"; import { zodResolver } from "@hookform/resolvers/zod"; +import { commands as analyticsCommands } from "@hypr/plugin-analytics"; import { commands as connectorCommands } from "@hypr/plugin-connector"; import { commands as dbCommands } from "@hypr/plugin-db"; import { commands as localLlmCommands } from "@hypr/plugin-local-llm"; +import { Trans } from "@lingui/react/macro"; import { useForm } from "react-hook-form"; import { z } from "zod"; import { AudioPermissionsView } from "./audio-permissions-view"; -// import { CalendarPermissionsView } from "./calendar-permissions-view"; -import { useHypr } from "@/contexts"; -import { commands as analyticsCommands } from "@hypr/plugin-analytics"; -import { Trans } from "@lingui/react/macro"; import { CustomEndpointView } from "./custom-endpoint-view"; import { DownloadProgressView } from "./download-progress-view"; import { LanguageSelectionView } from "./language-selection-view"; @@ -72,13 +71,13 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) { | "custom-endpoint" | "language-selection" >("welcome"); - const [selectedSttModel, setSelectedSttModel] = 
useState("QuantizedSmall"); + const [selectedSttModel, setSelectedSttModel] = useState("QuantizedSmall"); const [wentThroughDownloads, setWentThroughDownloads] = useState(false); const [llmSelection, setLlmSelection] = useState<"hyprllm" | "byom" | null>(null); const [cameFromLlmSelection, setCameFromLlmSelection] = useState(false); const selectSTTModel = useMutation({ - mutationFn: (model: SupportedModel) => localSttCommands.setCurrentModel(model), + mutationFn: (model: WhisperModel) => localSttCommands.setCurrentModel(model), }); const openaiForm = useForm<{ api_key: string; model: string }>({ @@ -253,7 +252,7 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) { setCurrentStep("audio-permissions"); }; - const handleModelSelected = (model: SupportedModel) => { + const handleModelSelected = (model: WhisperModel) => { selectSTTModel.mutate(model); setSelectedSttModel(model); sessionStorage.setItem("model-download-toast-dismissed", "true"); @@ -305,13 +304,13 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) { useEffect(() => { if (!isOpen && wentThroughDownloads) { - localSttCommands.startServer(); + localSttCommands.startServer(null); localLlmCommands.startServer(); const checkAndShowToasts = async () => { try { - const sttModelExists = await localSttCommands.isModelDownloaded(selectedSttModel as SupportedModel); + const sttModelExists = await localSttCommands.isModelDownloaded(selectedSttModel as WhisperModel); if (!sttModelExists) { showSttModelDownloadToast(selectedSttModel, undefined, queryClient); diff --git a/apps/desktop/src/components/welcome-modal/model-selection-view.tsx b/apps/desktop/src/components/welcome-modal/model-selection-view.tsx index a38c69c94..b8d01ce6a 100644 --- a/apps/desktop/src/components/welcome-modal/model-selection-view.tsx +++ b/apps/desktop/src/components/welcome-modal/model-selection-view.tsx @@ -5,7 +5,7 @@ import React, { useState } from "react"; import { Card, CardContent } from 
"@hypr/ui/components/ui/card"; -import { SupportedModel } from "@hypr/plugin-local-stt"; +import { type WhisperModel } from "@hypr/plugin-local-stt"; import { commands as localSttCommands } from "@hypr/plugin-local-stt"; import PushableButton from "@hypr/ui/components/ui/pushable-button"; import { cn } from "@hypr/ui/lib/utils"; @@ -45,9 +45,9 @@ const RatingDisplay = ( export const ModelSelectionView = ({ onContinue, }: { - onContinue: (model: SupportedModel) => void; + onContinue: (model: WhisperModel) => void; }) => { - const [selectedModel, setSelectedModel] = useState("QuantizedSmall"); + const [selectedModel, setSelectedModel] = useState("QuantizedSmall"); const supportedSTTModels = useQuery({ queryKey: ["local-stt", "supported-models"], @@ -82,7 +82,7 @@ export const ModelSelectionView = ({ }) ?.map(modelInfo => { const model = modelInfo.model; - const metadata = sttModelMetadata[model as SupportedModel]; + const metadata = sttModelMetadata[model as WhisperModel]; if (!metadata) { return null; } @@ -99,7 +99,7 @@ export const ModelSelectionView = ({ ? "ring-2 ring-blue-500 border-blue-500 bg-blue-50" : "hover:border-gray-400", )} - onClick={() => setSelectedModel(model as SupportedModel)} + onClick={() => setSelectedModel(model as WhisperModel)} >
diff --git a/apps/desktop/src/locales/en/messages.po b/apps/desktop/src/locales/en/messages.po index 75c4b12f8..7888bf4ca 100644 --- a/apps/desktop/src/locales/en/messages.po +++ b/apps/desktop/src/locales/en/messages.po @@ -444,8 +444,8 @@ msgstr "Audio Permissions" msgid "Autonomy Selector" msgstr "Autonomy Selector" -#: src/components/welcome-modal/index.tsx:351 -#: src/components/welcome-modal/index.tsx:362 +#: src/components/welcome-modal/index.tsx:350 +#: src/components/welcome-modal/index.tsx:361 msgid "Back" msgstr "Back" @@ -1189,7 +1189,7 @@ msgstr "Pause" msgid "people" msgstr "people" -#: src/components/settings/components/ai/stt-view.tsx:332 +#: src/components/settings/components/ai/stt-view.tsx:328 msgid "Performance difference between languages" msgstr "Performance difference between languages" @@ -1510,7 +1510,7 @@ msgstr "Toggle left sidebar" msgid "Toggle widget panel" msgstr "Toggle widget panel" -#: src/components/settings/components/ai/stt-view.tsx:323 +#: src/components/settings/components/ai/stt-view.tsx:319 msgid "Transcribing" msgstr "Transcribing" diff --git a/apps/desktop/src/locales/ko/messages.po b/apps/desktop/src/locales/ko/messages.po index 32249a4ac..fd1ce8ccb 100644 --- a/apps/desktop/src/locales/ko/messages.po +++ b/apps/desktop/src/locales/ko/messages.po @@ -444,8 +444,8 @@ msgstr "" msgid "Autonomy Selector" msgstr "" -#: src/components/welcome-modal/index.tsx:351 -#: src/components/welcome-modal/index.tsx:362 +#: src/components/welcome-modal/index.tsx:350 +#: src/components/welcome-modal/index.tsx:361 msgid "Back" msgstr "" @@ -1189,7 +1189,7 @@ msgstr "" msgid "people" msgstr "" -#: src/components/settings/components/ai/stt-view.tsx:332 +#: src/components/settings/components/ai/stt-view.tsx:328 msgid "Performance difference between languages" msgstr "" @@ -1510,7 +1510,7 @@ msgstr "" msgid "Toggle widget panel" msgstr "" -#: src/components/settings/components/ai/stt-view.tsx:323 +#: 
src/components/settings/components/ai/stt-view.tsx:319 msgid "Transcribing" msgstr "" diff --git a/apps/desktop/src/routes/app.tsx b/apps/desktop/src/routes/app.tsx index 9c805f608..9a9828422 100644 --- a/apps/desktop/src/routes/app.tsx +++ b/apps/desktop/src/routes/app.tsx @@ -166,7 +166,11 @@ function RestartSTT() { const sttPath = await localSttCommands.modelsDir(); return watch(sttPath, (_event) => { - localSttCommands.restartServer(); + localSttCommands.stopServer("internal").then((stopped) => { + if (stopped) { + localSttCommands.startServer("internal"); + } + }); }, { delayMs: 1000 }); }; diff --git a/crates/am/Cargo.toml b/crates/am/Cargo.toml new file mode 100644 index 000000000..787e8c80b --- /dev/null +++ b/crates/am/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "am" +version = "0.1.0" +edition = "2021" + +[dev-dependencies] +tokio = { workspace = true, features = ["rt", "macros"] } + +[dependencies] +reqwest = { workspace = true, features = ["json"] } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +thiserror = { workspace = true } diff --git a/crates/am/src/client.rs b/crates/am/src/client.rs new file mode 100644 index 000000000..06f217eca --- /dev/null +++ b/crates/am/src/client.rs @@ -0,0 +1,205 @@ +use crate::{ + ComputeUnits, Error, ErrorResponse, GenericResponse, InitRequest, InitResponse, ServerStatus, +}; +use reqwest::{Response, StatusCode}; + +#[derive(Clone)] +pub struct AmClient { + client: reqwest::Client, + base_url: String, +} + +impl AmClient { + pub fn new(base_url: impl Into) -> Self { + Self { + client: reqwest::Client::new(), + base_url: base_url.into(), + } + } + + pub fn with_client(client: reqwest::Client, base_url: impl Into) -> Self { + Self { + client, + base_url: base_url.into(), + } + } + + pub async fn status(&self) -> Result { + let url = format!("{}/v1/status", self.base_url); + let response = self.client.get(&url).send().await?; + + if response.status().is_success() { + 
Ok(response.json().await?) + } else { + Err(self.handle_error_response(response).await) + } + } + + pub async fn wait_for_ready( + &self, + max_wait_time: Option, + poll_interval: Option, + ) -> Result { + let url = format!("{}/v1/waitForReady", self.base_url); + let mut request = self.client.get(&url); + + if let Some(max_wait) = max_wait_time { + request = request.query(&[("maxWaitTime", max_wait)]); + } + + if let Some(interval) = poll_interval { + request = request.query(&[("pollInterval", interval)]); + } + + let response = request.send().await?; + + match response.status() { + StatusCode::OK => Ok(response.json().await?), + StatusCode::BAD_REQUEST | StatusCode::REQUEST_TIMEOUT => { + Err(self.handle_error_response(response).await) + } + _ => Err(Error::UnexpectedResponse), + } + } + + pub async fn init(&self, request: InitRequest) -> Result { + if !request.api_key.starts_with("ax_") { + return Err(Error::InvalidApiKey); + } + + let url = format!("{}/v1/init", self.base_url); + let response = self.client.post(&url).json(&request).send().await?; + + match response.status() { + StatusCode::OK => Ok(response.json().await?), + StatusCode::BAD_REQUEST | StatusCode::CONFLICT => { + Err(self.handle_error_response(response).await) + } + _ => Err(Error::UnexpectedResponse), + } + } + + pub async fn reset(&self) -> Result { + let url = format!("{}/v1/reset", self.base_url); + let response = self.client.post(&url).send().await?; + + if response.status().is_success() { + Ok(response.json().await?) 
+ } else { + Err(self.handle_error_response(response).await) + } + } + + pub async fn unload(&self) -> Result { + let url = format!("{}/v1/unload", self.base_url); + let response = self.client.post(&url).send().await?; + + match response.status() { + StatusCode::OK => Ok(response.json().await?), + StatusCode::BAD_REQUEST => Err(self.handle_error_response(response).await), + _ => Err(Error::UnexpectedResponse), + } + } + + pub async fn shutdown(&self) -> Result { + let url = format!("{}/v1/shutdown", self.base_url); + let response = self.client.post(&url).send().await?; + + if response.status().is_success() { + Ok(response.json().await?) + } else { + Err(self.handle_error_response(response).await) + } + } + + async fn handle_error_response(&self, response: Response) -> Error { + if let Ok(error_response) = response.json::().await { + Error::ServerError { + status: error_response.status, + message: error_response.message, + } + } else { + Error::UnexpectedResponse + } + } +} + +impl InitRequest { + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: api_key.into(), + model: None, + model_token: None, + download_base: None, + model_repo: None, + model_folder: None, + tokenizer_folder: None, + fast_load: None, + fast_load_encoder_compute_units: None, + fast_load_decoder_compute_units: None, + model_vad: None, + verbose: None, + } + } + + pub fn with_model(mut self, model: impl Into) -> Self { + self.model = Some(model.into()); + self + } + + pub fn with_model_token(mut self, token: impl Into) -> Self { + self.model_token = Some(token.into()); + self + } + + pub fn with_download_base(mut self, download_base: impl Into) -> Self { + self.download_base = Some(download_base.into()); + self + } + + pub fn with_model_repo(mut self, repo: impl Into) -> Self { + self.model_repo = Some(repo.into()); + self + } + + pub fn with_model_folder(mut self, folder: impl Into) -> Self { + self.model_folder = Some(folder.into()); + self + } + + pub fn with_tokenizer_folder(mut 
self, folder: impl Into<String>) -> Self { + self.tokenizer_folder = Some(folder.into()); + self + } + + pub fn with_fast_load(mut self, fast_load: bool) -> Self { + self.fast_load = Some(fast_load); + self + } + + pub fn with_encoder_compute_units(mut self, units: ComputeUnits) -> Self { + self.fast_load_encoder_compute_units = Some(units); + self + } + + pub fn with_decoder_compute_units(mut self, units: ComputeUnits) -> Self { + self.fast_load_decoder_compute_units = Some(units); + self + } + + pub fn with_model_vad(mut self, vad: bool) -> Self { + self.model_vad = Some(vad); + self + } + + pub fn with_verbose(mut self, verbose: bool) -> Self { + self.verbose = Some(verbose); + self + } +} + +impl Default for AmClient { + fn default() -> Self { + Self::new("http://localhost:50060") + } +} diff --git a/crates/am/src/error.rs b/crates/am/src/error.rs new file mode 100644 index 000000000..fca7b98ad --- /dev/null +++ b/crates/am/src/error.rs @@ -0,0 +1,14 @@ +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error(transparent)] + Request(#[from] reqwest::Error), + + #[error("Server returned error: {status} - {message}")] + ServerError { status: String, message: String }, + + #[error("Invalid API key format: must start with 'ax_'")] + InvalidApiKey, + + #[error("Unexpected response from server")] + UnexpectedResponse, +} diff --git a/crates/am/src/lib.rs b/crates/am/src/lib.rs new file mode 100644 index 000000000..654b52e35 --- /dev/null +++ b/crates/am/src/lib.rs @@ -0,0 +1,20 @@ +mod client; +mod error; +mod types; + +pub use client::*; +pub use error::*; +pub use types::*; + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_client_creation() { + let client = AmClient::new("http://localhost:50060"); + let status = client.status().await; + println!("{:?}", status); + // smoke test: reaching here means the request path completed without panicking + } +} diff --git a/crates/am/src/types.rs b/crates/am/src/types.rs new file mode 100644 index 000000000..40f6848c6 --- /dev/null +++ b/crates/am/src/types.rs 
@@ -0,0 +1,81 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ServerStatus { + pub status: ServerStatusType, + pub model: String, + pub version: String, + pub model_state: String, + pub verbose: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum ServerStatusType { + Ready, + Initializing, + Uninitialized, + Unloaded, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitRequest { + pub api_key: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub download_base: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_repo: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_folder: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokenizer_folder: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub fast_load: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub fast_load_encoder_compute_units: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub fast_load_decoder_compute_units: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_vad: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub verbose: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ComputeUnits { + Cpu, + #[serde(rename = "cpuandgpu")] + CpuAndGpu, + #[serde(rename = "cpuandneuralengine")] + CpuAndNeuralEngine, + All, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InitResponse { + pub status: String, + pub message: String, + pub model: String, + pub verbose: bool, 
+} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenericResponse { + pub status: String, + pub message: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorResponse { + pub status: String, + pub message: String, +} diff --git a/crates/file/Cargo.toml b/crates/file/Cargo.toml index 244d23e39..934856242 100644 --- a/crates/file/Cargo.toml +++ b/crates/file/Cargo.toml @@ -11,10 +11,11 @@ thiserror = { workspace = true } futures-util = { workspace = true } reqwest = { workspace = true, features = ["multipart", "stream"] } tokio = { workspace = true, features = ["rt", "macros", "fs"] } +tracing = { workspace = true } [dev-dependencies] dirs = { workspace = true } -hypr-s3 = { path = "../../crates/s3", package = "s3" } +hypr-s3 = { workspace = true } tempfile = { workspace = true } testcontainers-modules = { workspace = true, features = ["minio"] } wiremock = "0.5" diff --git a/crates/file/src/lib.rs b/crates/file/src/lib.rs index 4c2004116..2714ad1e7 100644 --- a/crates/file/src/lib.rs +++ b/crates/file/src/lib.rs @@ -151,8 +151,7 @@ pub async fn download_file_parallel( } let head_response = get_client().head(url.clone()).send().await?; - let total_size = get_content_length_from_headers(&head_response) - .ok_or_else(|| OtherError("Content-Length header missing".to_string()))?; + let total_size = get_content_length_from_headers(&head_response); let supports_ranges = head_response .headers() @@ -161,13 +160,15 @@ pub async fn download_file_parallel( .unwrap_or("") == "bytes"; - if !supports_ranges || total_size <= DEFAULT_CHUNK_SIZE { + if !supports_ranges || total_size.unwrap_or(0) <= DEFAULT_CHUNK_SIZE { return download_file_with_callback(url, output_path, move |progress| { progress_callback(progress) }) .await; } + let total_size = total_size.unwrap(); + let existing_size = if output_path.as_ref().exists() { file_size(&output_path)? 
} else { diff --git a/crates/moonshine/Cargo.toml b/crates/moonshine/Cargo.toml new file mode 100644 index 000000000..6910016d2 --- /dev/null +++ b/crates/moonshine/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "moonshine" +version = "0.1.0" +edition = "2021" + +[features] +default = [] +cuda = ["hypr-onnx/cuda"] +coreml = ["hypr-onnx/coreml"] +directml = ["hypr-onnx/directml"] + +[dependencies] +hypr-onnx = { workspace = true } +owhisper-config = { workspace = true } + +thiserror = { workspace = true } +tokenizers = { workspace = true } + +[dev-dependencies] +dirs = { workspace = true } +hypr-data = { workspace = true } +rodio = { workspace = true } diff --git a/crates/moonshine/src/error.rs b/crates/moonshine/src/error.rs new file mode 100644 index 000000000..f9adb5018 --- /dev/null +++ b/crates/moonshine/src/error.rs @@ -0,0 +1,23 @@ +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] std::io::Error), + + #[error(transparent)] + HyprOnnx(#[from] hypr_onnx::Error), + + #[error(transparent)] + Ort(#[from] hypr_onnx::ort::Error), + + #[error("invalid model name: {0}")] + InvalidModelName(String), + + #[error("shape error: {0}")] + Shape(String), + + #[error("tokenizer load error: {0}")] + TokenizerLoad(String), + + #[error("other: {0}")] + Other(String), +} diff --git a/crates/moonshine/src/lib.rs b/crates/moonshine/src/lib.rs new file mode 100644 index 000000000..8857ed443 --- /dev/null +++ b/crates/moonshine/src/lib.rs @@ -0,0 +1,5 @@ +mod error; +mod model; + +pub use error::*; +pub use model::*; diff --git a/crates/transcribe-moonshine/src/model.rs b/crates/moonshine/src/model.rs similarity index 100% rename from crates/transcribe-moonshine/src/model.rs rename to crates/moonshine/src/model.rs diff --git a/crates/transcribe-deepgram/src/lib.rs b/crates/transcribe-deepgram/src/lib.rs index 42ba7b027..835f3fdd2 100644 --- a/crates/transcribe-deepgram/src/lib.rs +++ b/crates/transcribe-deepgram/src/lib.rs @@ -6,8 +6,6 @@ pub 
use service::*; #[cfg(test)] mod tests { use super::*; - - use futures_util::StreamExt; use hypr_audio_utils::AudioFormatExt; #[tokio::test] @@ -44,16 +42,6 @@ mod tests { let stream = client.from_realtime_audio(audio).await.unwrap(); futures_util::pin_mut!(stream); - while let Some(result) = stream.next().await { - let owhisper_interface::ListenOutputChunk { words, .. } = result; - let text = words - .iter() - .map(|w| w.text.clone()) - .collect::>() - .join(" "); - println!("- {}", text); - } - server_handle.abort(); Ok(()) } diff --git a/crates/transcribe-moonshine/Cargo.toml b/crates/transcribe-moonshine/Cargo.toml index a1f5dd8f7..8e3529a7c 100644 --- a/crates/transcribe-moonshine/Cargo.toml +++ b/crates/transcribe-moonshine/Cargo.toml @@ -5,27 +5,19 @@ edition = "2021" [features] default = [] -cuda = ["hypr-onnx/cuda"] -coreml = ["hypr-onnx/coreml"] -directml = ["hypr-onnx/directml"] - -[dev-dependencies] -dirs = { workspace = true } -hypr-data = { workspace = true } -rodio = { workspace = true } +cuda = ["hypr-moonshine/cuda"] +coreml = ["hypr-moonshine/coreml"] +directml = ["hypr-moonshine/directml"] [dependencies] +hypr-moonshine = { workspace = true } owhisper-config = { workspace = true } owhisper-interface = { workspace = true } hypr-audio-utils = { workspace = true } hypr-chunker = { workspace = true } -hypr-onnx = { workspace = true } hypr-ws-utils = { workspace = true } -kalosm-sound = { workspace = true } -tokenizers = { workspace = true } - serde = { workspace = true, features = ["derive"] } serde_json = { workspace = true } serde_qs = { workspace = true } diff --git a/crates/transcribe-moonshine/src/error.rs b/crates/transcribe-moonshine/src/error.rs index f9adb5018..6bd7f82f2 100644 --- a/crates/transcribe-moonshine/src/error.rs +++ b/crates/transcribe-moonshine/src/error.rs @@ -4,10 +4,7 @@ pub enum Error { Io(#[from] std::io::Error), #[error(transparent)] - HyprOnnx(#[from] hypr_onnx::Error), - - #[error(transparent)] - Ort(#[from] 
hypr_onnx::ort::Error), + Moonshine(#[from] hypr_moonshine::Error), #[error("invalid model name: {0}")] InvalidModelName(String), diff --git a/crates/transcribe-moonshine/src/lib.rs b/crates/transcribe-moonshine/src/lib.rs index 74fd05ba5..0688ee59d 100644 --- a/crates/transcribe-moonshine/src/lib.rs +++ b/crates/transcribe-moonshine/src/lib.rs @@ -1,7 +1,5 @@ mod error; -mod model; mod service; pub use error::*; -pub use model::*; pub use service::*; diff --git a/crates/transcribe-moonshine/src/service/streaming.rs b/crates/transcribe-moonshine/src/service/streaming.rs index 569f9fd4e..e6849e23b 100644 --- a/crates/transcribe-moonshine/src/service/streaming.rs +++ b/crates/transcribe-moonshine/src/service/streaming.rs @@ -18,10 +18,10 @@ use futures_util::{SinkExt, StreamExt}; use tower::Service; use hypr_chunker::VadExt; -use owhisper_interface::{ListenOutputChunk, ListenParams, Word2}; +use hypr_moonshine::MoonshineOnnxModel; -use crate::MoonshineOnnxModel; use owhisper_config::MoonshineModelSize; +use owhisper_interface::{Alternatives, Channel, ListenParams, Metadata, StreamResponse, Word}; #[derive(Clone)] pub struct TranscribeService { @@ -205,10 +205,10 @@ async fn handle_dual_channel( async fn process_transcription_stream( mut ws_sender: futures_util::stream::SplitSink, - mut stream: Pin + Send>>, + mut stream: Pin + Send>>, ) { - while let Some(chunk) = stream.next().await { - let msg = Message::Text(serde_json::to_string(&chunk).unwrap().into()); + while let Some(response) = stream.next().await { + let msg = Message::Text(serde_json::to_string(&response).unwrap().into()); if let Err(e) = ws_sender.send(msg).await { tracing::warn!("websocket_send_error: {}", e); break; @@ -222,7 +222,7 @@ fn process_vad_stream( stream: S, model: Arc>, source_name: &str, -) -> impl futures_util::Stream +) -> impl futures_util::Stream where S: futures_util::Stream>, E: std::fmt::Display, @@ -252,32 +252,50 @@ where model_guard.transcribe(chunk.samples).unwrap() }; - let 
speaker = match source_name.as_str() { - "mic" => { - Some(owhisper_interface::SpeakerIdentity::Unassigned { index: 0 }) - } - "speaker" => { - Some(owhisper_interface::SpeakerIdentity::Unassigned { index: 1 }) - } - _ => None, + let (speaker, channel_index) = match source_name.as_str() { + "mic" => (Some(0), vec![0]), + "speaker" => (Some(1), vec![1]), + _ => (None, vec![0]), }; - let data = ListenOutputChunk { - meta: None, - words: text - .split_whitespace() - .filter(|w| !w.is_empty()) - .map(|w| Word2 { - text: w.trim().to_string(), - speaker: speaker.clone(), - start_ms: None, - end_ms: None, - confidence: None, - }) - .collect(), + let start_f64 = 0.0; + let duration_f64 = 0.0; + let confidence = 1.0; + + let words: Vec = text + .split_whitespace() + .filter(|w| !w.is_empty()) + .map(|w| Word { + word: w.to_string(), + start: start_f64, + end: start_f64 + duration_f64, + confidence, + speaker: speaker.clone(), + punctuated_word: None, + language: None, + }) + .collect(); + + let response = StreamResponse::TranscriptResponse { + type_field: "Results".to_string(), + start: start_f64, + duration: duration_f64, + is_final: true, + speech_final: true, + from_finalize: false, + channel: Channel { + alternatives: vec![Alternatives { + transcript: text.clone(), + languages: vec![], + words, + confidence, + }], + }, + metadata: Metadata::default(), + channel_index, }; - Some(data) + Some(response) } } } diff --git a/crates/transcribe-whisper-local/src/lib.rs b/crates/transcribe-whisper-local/src/lib.rs index 2f22e964b..541d40cbe 100644 --- a/crates/transcribe-whisper-local/src/lib.rs +++ b/crates/transcribe-whisper-local/src/lib.rs @@ -8,7 +8,6 @@ pub use service::*; // cargo test -p transcribe-whisper-local test_service -- --nocapture mod tests { use super::*; - use futures_util::StreamExt; use hypr_audio_utils::AudioFormatExt; #[tokio::test] @@ -45,16 +44,6 @@ mod tests { let stream = client.from_realtime_audio(audio).await.unwrap(); futures_util::pin_mut!(stream); 
- while let Some(result) = stream.next().await { - let owhisper_interface::ListenOutputChunk { words, .. } = result; - let text = words - .iter() - .map(|w| w.text.clone()) - .collect::>() - .join(" "); - println!("- {}", text); - } - server_handle.abort(); Ok(()) } diff --git a/crates/transcribe-whisper-local/src/service/streaming.rs b/crates/transcribe-whisper-local/src/service/streaming.rs index cd8fff59c..e4784f6d6 100644 --- a/crates/transcribe-whisper-local/src/service/streaming.rs +++ b/crates/transcribe-whisper-local/src/service/streaming.rs @@ -19,7 +19,7 @@ use tower::Service; use hypr_chunker::VadExt; use hypr_ws_utils::{ConnectionGuard, ConnectionManager}; -use owhisper_interface::{ListenOutputChunk, ListenParams, Word2}; +use owhisper_interface::{Alternatives, Channel, ListenParams, Metadata, StreamResponse, Word}; #[derive(Clone)] pub struct TranscribeService { @@ -200,37 +200,57 @@ async fn process_transcription_stream( let meta = chunk.meta(); let text = chunk.text().to_string(); - let start = chunk.start() as u64; - let duration = chunk.duration() as u64; - let confidence = chunk.confidence(); + let language = chunk.language().map(|s| s.to_string()).map(|s| vec![s]).unwrap_or_default(); + let start_f64 = chunk.start() as f64; + let duration_f64 = chunk.duration() as f64; + let confidence = chunk.confidence() as f64; let source = meta.and_then(|meta| meta.get("source") .and_then(|v| v.as_str()) .map(|s| s.to_string()) ); - let speaker = match source { - Some(s) if s == "mic" => Some(owhisper_interface::SpeakerIdentity::Unassigned { index: 0 }), - Some(s) if s == "speaker" => Some(owhisper_interface::SpeakerIdentity::Unassigned { index: 1 }), - _ => None, + + let (speaker, channel_index) = match source.as_deref() { + Some("mic") => (Some(0), vec![0]), + Some("speaker") => (Some(1), vec![1]), + _ => (None, vec![0]), }; - let data = ListenOutputChunk { - meta: None, - words: text - .split_whitespace() - .filter(|w| !w.is_empty()) - .map(|w| Word2 { - 
text: w.trim().to_string(), - speaker: speaker.clone(), - start_ms: Some(start), - end_ms: Some(start + duration), - confidence: Some(confidence), - }) - .collect(), + let words: Vec = text + .split_whitespace() + .filter(|w| !w.is_empty()) + .map(|w| Word { + word: w.to_string(), + start: start_f64, + end: start_f64 + duration_f64, + confidence, + speaker: speaker.clone(), + punctuated_word: None, + language: None, + }) + .collect(); + + let response = StreamResponse::TranscriptResponse { + type_field: "Results".to_string(), + start: start_f64, + duration: duration_f64, + is_final: true, + speech_final: true, + from_finalize: false, + channel: Channel{ + alternatives: vec![Alternatives{ + transcript: text.clone(), + languages: language.clone(), + words, + confidence, + }], + }, + metadata: Metadata::default(), + channel_index, }; - let msg = Message::Text(serde_json::to_string(&data).unwrap().into()); + let msg = Message::Text(serde_json::to_string(&response).unwrap().into()); if let Err(e) = ws_sender.send(msg).await { tracing::warn!("websocket_send_error: {}", e); break; diff --git a/crates/whisper-local/src/model.rs b/crates/whisper-local/src/model.rs index 70a94d1e5..4ed7a8ef8 100644 --- a/crates/whisper-local/src/model.rs +++ b/crates/whisper-local/src/model.rs @@ -161,6 +161,7 @@ impl Whisper { segments.push(Segment { text, + language: language.clone(), start: start as f32 / 1000.0, end: end as f32 / 1000.0, confidence, @@ -177,6 +178,7 @@ impl Whisper { .join(" "); if !full_text.is_empty() { + tracing::info!(text = ?full_text, "transcribe_completed"); self.dynamic_prompt = full_text; } @@ -304,6 +306,7 @@ impl Whisper { #[derive(Debug, Default)] pub struct Segment { pub text: String, + pub language: Option, pub start: f32, pub end: f32, pub confidence: f32, @@ -315,6 +318,10 @@ impl Segment { &self.text } + pub fn language(&self) -> Option<&str> { + self.language.as_deref() + } + pub fn start(&self) -> f32 { self.start } diff --git 
a/crates/ws-utils/src/lib.rs b/crates/ws-utils/src/lib.rs index 81a5a3ddb..4033f1c65 100644 --- a/crates/ws-utils/src/lib.rs +++ b/crates/ws-utils/src/lib.rs @@ -15,8 +15,38 @@ enum AudioProcessResult { End, } -fn process_ws_message(message: Message) -> AudioProcessResult { +fn deinterleave_audio(data: &[u8]) -> (Vec, Vec) { + let samples: Vec = data + .chunks_exact(2) + .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]])) + .collect(); + + let mut mic = Vec::with_capacity(samples.len() / 2); + let mut speaker = Vec::with_capacity(samples.len() / 2); + + for chunk in samples.chunks_exact(2) { + mic.push(chunk[0] as f32 / 32768.0); + speaker.push(chunk[1] as f32 / 32768.0); + } + + (mic, speaker) +} + +fn process_ws_message(message: Message, channels: Option) -> AudioProcessResult { match message { + Message::Binary(data) => { + if data.is_empty() { + return AudioProcessResult::Empty; + } + + match channels { + Some(2) => { + let (mic, speaker) = deinterleave_audio(&data); + AudioProcessResult::DualSamples { mic, speaker } + } + _ => AudioProcessResult::Samples(bytes_to_f32_samples(&data)), + } + } Message::Text(data) => match serde_json::from_str::(&data) { Ok(ListenInputChunk::Audio { data }) => { if data.is_empty() { @@ -68,7 +98,7 @@ impl kalosm_sound::AsyncSource for WebSocketAudioSource { futures_util::stream::unfold(receiver, |receiver| async move { match receiver.next().await { - Some(Ok(message)) => match process_ws_message(message) { + Some(Ok(message)) => match process_ws_message(message, None) { AudioProcessResult::Samples(samples) => Some((samples, receiver)), AudioProcessResult::DualSamples { mic, speaker } => { let mixed = mix_audio_channels(&mic, &speaker); @@ -126,7 +156,7 @@ pub fn split_dual_audio_sources( tokio::spawn(async move { while let Some(Ok(message)) = ws_receiver.next().await { - match process_ws_message(message) { + match process_ws_message(message, Some(2)) { AudioProcessResult::Samples(samples) => { let _ = 
mic_tx.send(samples.clone()); let _ = speaker_tx.send(samples); diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index cf97e0dfd..695021e6c 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -1,7 +1,30 @@ use futures_util::{Stream, StreamExt}; use hypr_ws::client::{ClientRequestBuilder, Message, WebSocketClient, WebSocketIO}; -use owhisper_interface::{ListenInputChunk, ListenOutputChunk}; +use owhisper_interface::StreamResponse; + +fn interleave_audio(mic: &[u8], speaker: &[u8]) -> Vec { + let mic_samples: Vec = mic + .chunks_exact(2) + .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]])) + .collect(); + let speaker_samples: Vec = speaker + .chunks_exact(2) + .map(|chunk| i16::from_le_bytes([chunk[0], chunk[1]])) + .collect(); + + let max_len = mic_samples.len().max(speaker_samples.len()); + let mut interleaved = Vec::with_capacity(max_len * 2 * 2); + + for i in 0..max_len { + let mic_sample = mic_samples.get(i).copied().unwrap_or(0); + let speaker_sample = speaker_samples.get(i).copied().unwrap_or(0); + interleaved.extend_from_slice(&mic_sample.to_le_bytes()); + interleaved.extend_from_slice(&speaker_sample.to_le_bytes()); + } + + interleaved +} #[derive(Default)] pub struct ListenClientBuilder { @@ -49,11 +72,18 @@ impl ListenClientBuilder { for (i, lang) in params.languages.iter().enumerate() { query_pairs.append_pair(&format!("languages[{}]", i), lang.iso639().code()); } + + let channels = match params.audio_mode { + owhisper_interface::AudioMode::Single => "1", + owhisper_interface::AudioMode::Dual => "2", + }; + query_pairs // https://developers.deepgram.com/reference/speech-to-text-api/listen-streaming#handshake .append_pair("model", ¶ms.model.unwrap_or("hypr-whisper".to_string())) .append_pair("sample_rate", "16000") .append_pair("encoding", "linear16") + .append_pair("channels", channels) .append_pair("audio_mode", params.audio_mode.as_ref()) .append_pair("static_prompt", 
¶ms.static_prompt) .append_pair("dynamic_prompt", ¶ms.dynamic_prompt) @@ -103,17 +133,15 @@ pub struct ListenClient { impl WebSocketIO for ListenClient { type Data = bytes::Bytes; - type Input = ListenInputChunk; - type Output = ListenOutputChunk; + type Input = bytes::Bytes; + type Output = StreamResponse; fn to_input(data: Self::Data) -> Self::Input { - ListenInputChunk::Audio { - data: data.to_vec(), - } + data } fn to_message(input: Self::Input) -> Message { - Message::Text(serde_json::to_string(&input).unwrap().into()) + Message::Binary(input) } fn from_message(msg: Message) -> Option { @@ -131,18 +159,16 @@ pub struct ListenClientDual { impl WebSocketIO for ListenClientDual { type Data = (bytes::Bytes, bytes::Bytes); - type Input = ListenInputChunk; - type Output = ListenOutputChunk; + type Input = bytes::Bytes; + type Output = StreamResponse; fn to_input(data: Self::Data) -> Self::Input { - ListenInputChunk::DualAudio { - mic: data.0.to_vec(), - speaker: data.1.to_vec(), - } + let interleaved = interleave_audio(&data.0, &data.1); + bytes::Bytes::from(interleaved) } fn to_message(input: Self::Input) -> Message { - Message::Text(serde_json::to_string(&input).unwrap().into()) + Message::Binary(input) } fn from_message(msg: Message) -> Option { @@ -161,7 +187,7 @@ impl ListenClient { pub async fn from_realtime_audio( &self, audio_stream: impl Stream + Send + Unpin + 'static, - ) -> Result, hypr_ws::Error> { + ) -> Result, hypr_ws::Error> { let ws = WebSocketClient::new(self.request.clone()); ws.from_audio::(audio_stream).await } @@ -172,7 +198,7 @@ impl ListenClientDual { &self, mic_stream: impl Stream + Send + Unpin + 'static, speaker_stream: impl Stream + Send + Unpin + 'static, - ) -> Result, hypr_ws::Error> { + ) -> Result, hypr_ws::Error> { let dual_stream = mic_stream.zip(speaker_stream); let ws = WebSocketClient::new(self.request.clone()); ws.from_audio::(dual_stream).await @@ -187,7 +213,7 @@ mod tests { use hypr_audio_utils::AudioFormatExt; 
#[tokio::test] - #[ignore] + // cargo test -p owhisper-client test_client_deepgram -- --nocapture async fn test_client_deepgram() { let audio = rodio::Decoder::new(std::io::BufReader::new( std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), @@ -199,6 +225,7 @@ mod tests { .api_base("https://api.deepgram.com") .api_key(std::env::var("DEEPGRAM_API_KEY").unwrap()) .params(owhisper_interface::ListenParams { + model: Some("nova-2".to_string()), languages: vec![hypr_language::ISO639::En.into()], ..Default::default() }) @@ -212,6 +239,39 @@ mod tests { } } + #[tokio::test] + // cargo test -p owhisper-client test_client_ag -- --nocapture + async fn test_client_ag() { + let audio_1 = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), + )) + .unwrap() + .to_i16_le_chunks(16000, 512); + + let audio_2 = rodio::Decoder::new(std::io::BufReader::new( + std::fs::File::open(hypr_data::english_1::AUDIO_PATH).unwrap(), + )) + .unwrap() + .to_i16_le_chunks(16000, 512); + + let client = ListenClient::builder() + .api_base("ws://localhost:50060") + .api_key("".to_string()) + .params(owhisper_interface::ListenParams { + model: Some("tiny.en".to_string()), + languages: vec![hypr_language::ISO639::En.into()], + ..Default::default() + }) + .build_dual(); + + let stream = client.from_realtime_audio(audio_1, audio_2).await.unwrap(); + futures_util::pin_mut!(stream); + + while let Some(result) = stream.next().await { + println!("{:?}", result); + } + } + #[tokio::test] #[ignore] async fn test_client_owhisper() { @@ -225,6 +285,7 @@ mod tests { .api_base("ws://127.0.0.1:1234/v1/realtime") .api_key("".to_string()) .params(owhisper_interface::ListenParams { + model: Some("whisper-cpp".to_string()), languages: vec![hypr_language::ISO639::En.into()], ..Default::default() }) diff --git a/owhisper/owhisper-interface/Cargo.toml b/owhisper/owhisper-interface/Cargo.toml index 3f1afa456..1b3071825 100644 --- 
a/owhisper/owhisper-interface/Cargo.toml +++ b/owhisper/owhisper-interface/Cargo.toml @@ -12,9 +12,10 @@ hypr-language = { workspace = true } serde = { workspace = true, features = ["derive"] } serde_bytes = { workspace = true } serde_json = { workspace = true } -strum = { workspace = true, features = ["derive"] } chrono = { workspace = true, features = ["serde"] } codes-iso-639 = { workspace = true } schemars = { workspace = true } specta = { workspace = true, features = ["derive", "serde_json"] } +strum = { workspace = true, features = ["derive"] } +uuid = { workspace = true, features = ["v4"] } diff --git a/owhisper/owhisper-interface/src/lib.rs b/owhisper/owhisper-interface/src/lib.rs index 02210cd08..ae4d742c4 100644 --- a/owhisper/owhisper-interface/src/lib.rs +++ b/owhisper/owhisper-interface/src/lib.rs @@ -34,10 +34,9 @@ impl From for Word2 { fn from(word: Word) -> Self { Word2 { text: word.word, - speaker: word.speaker.map(|s| SpeakerIdentity::Assigned { - id: s.to_string(), - label: s.to_string(), - }), + speaker: word + .speaker + .map(|s| SpeakerIdentity::Unassigned { index: s as u8 }), confidence: Some(word.confidence as f32), start_ms: Some(word.start as u64), end_ms: Some(word.end as u64), diff --git a/owhisper/owhisper-interface/src/stream.rs b/owhisper/owhisper-interface/src/stream.rs index 6ac57e885..118750b48 100644 --- a/owhisper/owhisper-interface/src/stream.rs +++ b/owhisper/owhisper-interface/src/stream.rs @@ -47,6 +47,20 @@ common_derives! { } } +impl Default for Metadata { + fn default() -> Self { + Self { + request_id: uuid::Uuid::new_v4().to_string(), + model_uuid: uuid::Uuid::new_v4().to_string(), + model_info: ModelInfo { + name: "".to_string(), + version: "".to_string(), + arch: "".to_string(), + }, + } + } +} + common_derives! 
{ #[serde(untagged)] #[non_exhaustive] @@ -92,7 +106,7 @@ mod test { #[test] fn ensure_types() { let dg = DG::StreamResponse::TranscriptResponse { - type_field: "transcript".to_string(), + type_field: "Results".to_string(), start: 0.0, duration: 0.0, is_final: false, diff --git a/owhisper/owhisper-server/src/commands/run/mod.rs b/owhisper/owhisper-server/src/commands/run/mod.rs index f54ee9f1c..389feec05 100644 --- a/owhisper/owhisper-server/src/commands/run/mod.rs +++ b/owhisper/owhisper-server/src/commands/run/mod.rs @@ -34,9 +34,8 @@ pub async fn handle_run(args: RunArgs) -> anyhow::Result<()> { log::set_max_level(log::LevelFilter::Off); let config = owhisper_config::Config::new(args.config.clone())?; - let api_key = config.general.as_ref().and_then(|g| g.api_key.clone()); - let server = Server::new(config, None); + let server = Server::new(config.clone(), None); let router = server.build_router().await?; let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?; diff --git a/owhisper/owhisper-server/src/commands/run/realtime.rs b/owhisper/owhisper-server/src/commands/run/realtime.rs index abd5e6ae2..438e6b824 100644 --- a/owhisper/owhisper-server/src/commands/run/realtime.rs +++ b/owhisper/owhisper-server/src/commands/run/realtime.rs @@ -23,7 +23,7 @@ pub async fn handle_realtime_input( let (event_tx, mut event_rx) = create_event_channel(); let (transcript_tx, transcript_rx) = - mpsc::unbounded_channel::(); + mpsc::unbounded_channel::(); let amplitude_data = Arc::new(Mutex::new(AmplitudeData::new())); @@ -93,7 +93,7 @@ fn start_audio_task( port: u16, api_key: Option, model: String, - transcript_tx: mpsc::UnboundedSender, + transcript_tx: mpsc::UnboundedSender, amplitude_data: Arc>, ) -> std::sync::Arc { let should_stop = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false)); @@ -123,7 +123,7 @@ async fn run_audio_stream_with_stop( port: u16, api_key: Option, model: String, - transcript_tx: mpsc::UnboundedSender, + transcript_tx: 
mpsc::UnboundedSender, amplitude_data: Arc>, should_stop: std::sync::Arc, ) -> anyhow::Result<()> { @@ -183,7 +183,7 @@ async fn run_tui_with_events( available_devices: Vec, amplitude_data: Arc>, event_tx: TuiEventSender, - mut transcript_rx: mpsc::UnboundedReceiver, + mut transcript_rx: mpsc::UnboundedReceiver, ) -> anyhow::Result<()> { use ratatui::crossterm::event::{self, Event, KeyCode}; use std::time::{Duration, Instant}; diff --git a/owhisper/owhisper-server/src/commands/run/state.rs b/owhisper/owhisper-server/src/commands/run/state.rs index 27c0f5790..2626eeaa9 100644 --- a/owhisper/owhisper-server/src/commands/run/state.rs +++ b/owhisper/owhisper-server/src/commands/run/state.rs @@ -56,13 +56,28 @@ impl RunState { self.event_sender = Some(sender); } - pub fn process_chunk(&mut self, chunk: owhisper_interface::ListenOutputChunk) { - if chunk.words.is_empty() { + pub fn process_chunk(&mut self, chunk: owhisper_interface::StreamResponse) { + let words = match chunk { + owhisper_interface::StreamResponse::TranscriptResponse { channel, .. 
} => channel + .alternatives + .first() + .map(|alt| { + alt.words + .iter() + .map(|w| owhisper_interface::Word2::from(w.clone())) + .collect::>() + }) + .unwrap_or_default(), + _ => { + return; + } + }; + + if words.is_empty() { return; } - let text = chunk - .words + let text = words .iter() .map(|w| w.text.as_str()) .collect::>() diff --git a/owhisper/owhisper-server/src/main.rs b/owhisper/owhisper-server/src/main.rs index 13b59cae8..3c0c326b9 100644 --- a/owhisper/owhisper-server/src/main.rs +++ b/owhisper/owhisper-server/src/main.rs @@ -50,8 +50,9 @@ async fn main() -> Result<(), Box> { Commands::Serve(args) => commands::handle_serve(args).await, }; - if result.is_err() { - log::error!("{}", result.unwrap_err()); + if let Err(e) = result { + log::error!("{}", e); + std::process::exit(1); } Ok(()) diff --git a/plugins/connector/src/ext.rs b/plugins/connector/src/ext.rs index c951a6e39..0d5c8a2b6 100644 --- a/plugins/connector/src/ext.rs +++ b/plugins/connector/src/ext.rs @@ -212,14 +212,11 @@ impl> ConnectorPluginExt for T { } { - use tauri_plugin_local_stt::{LocalSttPluginExt, SharedState}; - - let api_base = if self.is_server_running().await { - let state = self.state::(); - let guard = state.lock().await; - guard.api_base.clone().unwrap() - } else { - self.start_server().await? + use tauri_plugin_local_stt::LocalSttPluginExt; + + let api_base = match self.get_api_base(None).await? { + Some(api_base) => api_base, + None => self.start_server(None).await?, }; let conn = ConnectionSTT::HyprLocal(Connection { diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index f1ee129e4..0ad98a690 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -473,19 +473,35 @@ impl Session { loop { match tokio::time::timeout(LISTEN_STREAM_TIMEOUT, listen_stream.next()).await { Ok(Some(result)) => { - let _meta = result.meta.clone(); + let words = match result { + owhisper_interface::StreamResponse::TranscriptResponse { + channel, + .. 
+ } => channel + .alternatives + .first() + .map(|alt| { + alt.words + .iter() + .map(|w| owhisper_interface::Word2::from(w.clone())) + .collect::>() + }) + .unwrap_or_default(), + _ => { + continue; + } + }; - { - let updated_words = update_session(&app, &session.id, result.words) - .await - .unwrap(); + if !words.is_empty() { + let updated_words = + update_session(&app, &session.id, words).await.unwrap(); SessionEvent::Words { words: updated_words, } .emit(&app) + .unwrap(); } - .unwrap(); } Ok(None) => { tracing::info!("listen_stream_ended"); diff --git a/plugins/local-stt/Cargo.toml b/plugins/local-stt/Cargo.toml index 0b57bcfd4..c598d2590 100644 --- a/plugins/local-stt/Cargo.toml +++ b/plugins/local-stt/Cargo.toml @@ -40,6 +40,7 @@ tokio-tungstenite = { workspace = true } tower = { workspace = true } [dependencies] +hypr-am = { workspace = true } hypr-audio-utils = { workspace = true } hypr-file = { workspace = true } hypr-transcribe-moonshine = { workspace = true } diff --git a/plugins/local-stt/build.rs b/plugins/local-stt/build.rs index e3cbc1d08..5242dce59 100644 --- a/plugins/local-stt/build.rs +++ b/plugins/local-stt/build.rs @@ -1,13 +1,11 @@ const COMMANDS: &[&str] = &[ "models_dir", "list_ggml_backends", - "is_server_running", "is_model_downloaded", "is_model_downloading", "download_model", "start_server", "stop_server", - "restart_server", "get_current_model", "set_current_model", "list_supported_models", diff --git a/plugins/local-stt/js/bindings.gen.ts b/plugins/local-stt/js/bindings.gen.ts index 58b2abb84..fc4ce83aa 100644 --- a/plugins/local-stt/js/bindings.gen.ts +++ b/plugins/local-stt/js/bindings.gen.ts @@ -16,32 +16,29 @@ async listGgmlBackends() : Promise { async isServerRunning() : Promise { return await TAURI_INVOKE("plugin:local-stt|is_server_running"); }, -async isModelDownloaded(model: SupportedModel) : Promise { +async isModelDownloaded(model: WhisperModel) : Promise { return await TAURI_INVOKE("plugin:local-stt|is_model_downloaded", { 
model }); }, -async isModelDownloading(model: SupportedModel) : Promise { +async isModelDownloading(model: WhisperModel) : Promise { return await TAURI_INVOKE("plugin:local-stt|is_model_downloading", { model }); }, -async downloadModel(model: SupportedModel, channel: TAURI_CHANNEL) : Promise { +async downloadModel(model: WhisperModel, channel: TAURI_CHANNEL) : Promise { return await TAURI_INVOKE("plugin:local-stt|download_model", { model, channel }); }, -async listSupportedModels() : Promise { +async listSupportedModels() : Promise { return await TAURI_INVOKE("plugin:local-stt|list_supported_models"); }, -async getCurrentModel() : Promise { +async getCurrentModel() : Promise { return await TAURI_INVOKE("plugin:local-stt|get_current_model"); }, -async setCurrentModel(model: SupportedModel) : Promise { +async setCurrentModel(model: WhisperModel) : Promise { return await TAURI_INVOKE("plugin:local-stt|set_current_model", { model }); }, -async startServer() : Promise { - return await TAURI_INVOKE("plugin:local-stt|start_server"); +async startServer(serverType: ServerType | null) : Promise { + return await TAURI_INVOKE("plugin:local-stt|start_server", { serverType }); }, -async stopServer() : Promise { - return await TAURI_INVOKE("plugin:local-stt|stop_server"); -}, -async restartServer() : Promise { - return await TAURI_INVOKE("plugin:local-stt|restart_server"); +async stopServer(serverType: ServerType | null) : Promise { + return await TAURI_INVOKE("plugin:local-stt|stop_server", { serverType }); } } @@ -61,10 +58,11 @@ recordedProcessingEvent: "plugin:local-stt:recorded-processing-event" /** user-defined types **/ export type GgmlBackend = { kind: string; name: string; description: string; total_memory_mb: number; free_memory_mb: number } -export type RecordedProcessingEvent = { type: "progress"; current: number; total: number; word: Word } +export type RecordedProcessingEvent = { type: "progress"; current: number; total: number; word: Word2 } +export type ServerType 
= "internal" | "external" export type SpeakerIdentity = { type: "unassigned"; value: { index: number } } | { type: "assigned"; value: { id: string; label: string } } -export type SupportedModel = "QuantizedTiny" | "QuantizedTinyEn" | "QuantizedBase" | "QuantizedBaseEn" | "QuantizedSmall" | "QuantizedSmallEn" | "QuantizedLargeTurbo" -export type Word = { text: string; speaker: SpeakerIdentity | null; confidence: number | null; start_ms: number | null; end_ms: number | null } +export type WhisperModel = "QuantizedTiny" | "QuantizedTinyEn" | "QuantizedBase" | "QuantizedBaseEn" | "QuantizedSmall" | "QuantizedSmallEn" | "QuantizedLargeTurbo" +export type Word2 = { text: string; speaker: SpeakerIdentity | null; confidence: number | null; start_ms: number | null; end_ms: number | null } /** tauri-specta globals **/ diff --git a/plugins/local-stt/permissions/autogenerated/reference.md b/plugins/local-stt/permissions/autogenerated/reference.md index 8c9bc2f0b..53ad292dd 100644 --- a/plugins/local-stt/permissions/autogenerated/reference.md +++ b/plugins/local-stt/permissions/autogenerated/reference.md @@ -5,13 +5,11 @@ Default permissions for the plugin #### This default permission set includes the following: - `allow-models-dir` -- `allow-is-server-running` - `allow-is-model-downloaded` - `allow-is-model-downloading` - `allow-download-model` - `allow-start-server` - `allow-stop-server` -- `allow-restart-server` - `allow-get-current-model` - `allow-set-current-model` - `allow-list-supported-models` diff --git a/plugins/local-stt/permissions/default.toml b/plugins/local-stt/permissions/default.toml index 782b142a3..5ff190f8c 100644 --- a/plugins/local-stt/permissions/default.toml +++ b/plugins/local-stt/permissions/default.toml @@ -2,13 +2,11 @@ description = "Default permissions for the plugin" permissions = [ "allow-models-dir", - "allow-is-server-running", "allow-is-model-downloaded", "allow-is-model-downloading", "allow-download-model", "allow-start-server", 
"allow-stop-server", - "allow-restart-server", "allow-get-current-model", "allow-set-current-model", "allow-list-supported-models", diff --git a/plugins/local-stt/permissions/schemas/schema.json b/plugins/local-stt/permissions/schemas/schema.json index 03817bc0c..8f5a702f2 100644 --- a/plugins/local-stt/permissions/schemas/schema.json +++ b/plugins/local-stt/permissions/schemas/schema.json @@ -451,10 +451,10 @@ "markdownDescription": "Denies the stop_server command without any pre-configured scope." }, { - "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-restart-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`", + "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`", "type": "string", "const": "default", - "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-server-running`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-restart-server`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`" + "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- 
`allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`" } ] } diff --git a/plugins/local-stt/src/commands.rs b/plugins/local-stt/src/commands.rs index ebabf1af7..84acf9953 100644 --- a/plugins/local-stt/src/commands.rs +++ b/plugins/local-stt/src/commands.rs @@ -1,4 +1,4 @@ -use crate::LocalSttPluginExt; +use crate::{server::ServerType, LocalSttPluginExt}; use hypr_whisper_local_model::WhisperModel; use tauri::ipc::Channel; @@ -33,20 +33,6 @@ pub async fn list_supported_models() -> Result, String> { Ok(models) } -#[tauri::command] -#[specta::specta] -pub async fn list_custom_models( - app: tauri::AppHandle, -) -> Result, String> { - Ok(app.list_custom_models()) -} - -#[tauri::command] -#[specta::specta] -pub async fn is_server_running(app: tauri::AppHandle) -> bool { - app.is_server_running().await -} - #[tauri::command] #[specta::specta] pub async fn is_model_downloaded( @@ -98,19 +84,22 @@ pub fn set_current_model( #[tauri::command] #[specta::specta] -pub async fn start_server(app: tauri::AppHandle) -> Result { - app.start_server().await.map_err(|e| e.to_string()) -} - -#[tauri::command] -#[specta::specta] -pub async fn stop_server(app: tauri::AppHandle) -> Result<(), String> { - app.stop_server().await.map_err(|e| e.to_string()) +pub async fn start_server( + app: tauri::AppHandle, + server_type: Option, +) -> Result { + app.start_server(server_type) + .await + .map_err(|e| e.to_string()) } #[tauri::command] #[specta::specta] -pub async fn restart_server(app: tauri::AppHandle) -> Result { - app.stop_server().await.map_err(|e| e.to_string())?; - app.start_server().await.map_err(|e| e.to_string()) +pub async fn stop_server( + app: tauri::AppHandle, + server_type: Option, +) -> Result { + app.stop_server(server_type) + .await + .map_err(|e| e.to_string()) } diff --git a/plugins/local-stt/src/error.rs b/plugins/local-stt/src/error.rs index aa704acc8..e712c4371 100644 --- a/plugins/local-stt/src/error.rs +++ 
b/plugins/local-stt/src/error.rs @@ -4,6 +4,8 @@ pub type Result = std::result::Result; #[derive(Debug, thiserror::Error)] pub enum Error { + #[error(transparent)] + AmError(#[from] hypr_am::Error), #[error(transparent)] HyprFileError(#[from] hypr_file::Error), #[error(transparent)] diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 416e8af1a..c1d1ee22c 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -7,19 +7,25 @@ use tauri_plugin_store2::StorePluginExt; use hypr_file::{download_file_parallel, DownloadProgress}; use hypr_whisper_local_model::WhisperModel; +use crate::server::{external, internal, ServerType}; + pub trait LocalSttPluginExt { fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore; fn models_dir(&self) -> PathBuf; - fn list_custom_models(&self) -> Vec; fn list_ggml_backends(&self) -> Vec; - fn api_base(&self) -> impl Future>; - - fn start_external_server(&self) -> impl Future>; - fn stop_external_server(&self) -> impl Future>; - fn is_server_running(&self) -> impl Future; - fn start_server(&self) -> impl Future>; - fn stop_server(&self) -> impl Future>; + fn get_api_base( + &self, + server_type: Option, + ) -> impl Future, crate::Error>>; + fn start_server( + &self, + server_type: Option, + ) -> impl Future>; + fn stop_server( + &self, + server_type: Option, + ) -> impl Future>; fn get_current_model(&self) -> Result; fn set_current_model(&self, model: WhisperModel) -> Result<(), crate::Error>; @@ -50,41 +56,6 @@ impl> LocalSttPluginExt for T { hypr_whisper_local::list_ggml_backends() } - async fn api_base(&self) -> Option { - let state = self.state::(); - let s = state.lock().await; - - s.api_base.clone() - } - - fn list_custom_models(&self) -> Vec { - let models_dir = self.models_dir(); - let mut models = Vec::new(); - - for entry in models_dir.read_dir().unwrap() { - let entry = entry.unwrap(); - let path = entry.path(); - - if path.is_file() { - let file_name = 
path.file_name().unwrap().to_str().unwrap().to_string(); - let default_models = vec![ - WhisperModel::QuantizedTiny, - WhisperModel::QuantizedSmall, - WhisperModel::QuantizedLargeTurbo, - ] - .iter() - .map(|model| model.file_name().to_string()) - .collect::>(); - - if !default_models.contains(&file_name) { - models.push(file_name); - } - } - } - - models - } - async fn is_model_downloaded(&self, model: &WhisperModel) -> Result { let model_path = self.models_dir().join(model.file_name()); @@ -103,68 +74,111 @@ impl> LocalSttPluginExt for T { } #[tracing::instrument(skip_all)] - async fn start_external_server(&self) -> Result { - let port = 8008; - let cmd = self - .shell() - .sidecar("pro-stt-server")? - .arg(format!("--port {}", port)); - - let (_rx, _child) = cmd.spawn()?; - Ok(format!("http://localhost:{}", port)) + async fn get_api_base( + &self, + server_type: Option, + ) -> Result, crate::Error> { + let state = self.state::(); + let guard = state.lock().await; + + let internal_api_base = guard.internal_server.as_ref().map(|s| s.api_base.clone()); + let external_api_base = guard.external_server.as_ref().map(|s| s.api_base.clone()); + + match server_type { + Some(ServerType::Internal) => Ok(internal_api_base), + Some(ServerType::External) => Ok(external_api_base), + None => { + if let Some(external_api_base) = external_api_base { + Ok(Some(external_api_base)) + } else if let Some(internal_api_base) = internal_api_base { + Ok(Some(internal_api_base)) + } else { + Ok(None) + } + } + } } #[tracing::instrument(skip_all)] - async fn stop_external_server(&self) -> Result<(), crate::Error> { - Ok(()) - } + async fn start_server(&self, server_type: Option) -> Result { + let t = server_type.unwrap_or(ServerType::Internal); - #[tracing::instrument(skip_all)] - async fn is_server_running(&self) -> bool { - let state = self.state::(); - let s = state.lock().await; + match t { + ServerType::Internal => { + let cache_dir = self.models_dir(); + let model = 
self.get_current_model()?; - s.server.is_some() - } + if !self.is_model_downloaded(&model).await? { + return Err(crate::Error::ModelNotDownloaded); + } - #[tracing::instrument(skip_all)] - async fn start_server(&self) -> Result { - let cache_dir = self.models_dir(); - let model = self.get_current_model()?; + let server_state = internal::ServerState::builder() + .model_cache_dir(cache_dir) + .model_type(model) + .build(); - if !self.is_model_downloaded(&model).await? { - return Err(crate::Error::ModelNotDownloaded); - } + let server = internal::run_server(server_state).await?; + let api_base = server.api_base.clone(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; - let server_state = crate::ServerState::builder() - .model_cache_dir(cache_dir) - .model_type(model) - .build(); + { + let state = self.state::(); + let mut s = state.lock().await; + s.internal_server = Some(server); + } - let server = crate::run_server(server_state).await?; - tokio::time::sleep(std::time::Duration::from_millis(100)).await; + Ok(api_base) + } + ServerType::External => { + let cmd = self.shell().sidecar("stt")?; - let api_base = format!("http://{}", &server.addr); + let server = external::run_server(cmd).await?; + let api_base = server.api_base.clone(); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; - { - let state = self.state::(); - let mut s = state.lock().await; - s.api_base = Some(api_base.clone()); - s.server = Some(server); - } + { + let state = self.state::(); + let mut s = state.lock().await; + s.external_server = Some(server); + } - Ok(api_base) + Ok(api_base) + } + } } #[tracing::instrument(skip_all)] - async fn stop_server(&self) -> Result<(), crate::Error> { + async fn stop_server(&self, server_type: Option) -> Result { let state = self.state::(); let mut s = state.lock().await; - if let Some(server) = s.server.take() { - let _ = server.shutdown.send(()); + let mut stopped = false; + match server_type { + Some(ServerType::External) => { 
+ if let Some(server) = s.external_server.take() { + let _ = server.shutdown.send(()); + stopped = true; + } + } + Some(ServerType::Internal) => { + if let Some(server) = s.internal_server.take() { + let _ = server.shutdown.send(()); + stopped = true; + } + } + None => { + if let Some(server) = s.external_server.take() { + let _ = server.shutdown.send(()); + stopped = true; + } + if let Some(server) = s.internal_server.take() { + let _ = server.shutdown.send(()); + stopped = true; + } + } } - Ok(()) + + Ok(stopped) } #[tracing::instrument(skip_all)] diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index 79e4d3c1a..219e5714a 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -10,15 +10,15 @@ mod store; pub use error::*; pub use ext::*; -pub use server::*; pub use store::*; pub type SharedState = std::sync::Arc>; #[derive(Default)] pub struct State { - pub api_base: Option, - pub server: Option, + pub am_api_key: Option, + pub internal_server: Option, + pub external_server: Option, pub download_task: HashMap>, } @@ -30,7 +30,6 @@ fn make_specta_builder() -> tauri_specta::Builder { .commands(tauri_specta::collect_commands![ commands::models_dir::, commands::list_ggml_backends::, - commands::is_server_running::, commands::is_model_downloaded::, commands::is_model_downloading::, commands::download_model::, @@ -39,8 +38,8 @@ fn make_specta_builder() -> tauri_specta::Builder { commands::set_current_model::, commands::start_server::, commands::stop_server::, - commands::restart_server::, ]) + .typ::() .events(tauri_specta::collect_events![ events::RecordedProcessingEvent ]) @@ -79,7 +78,23 @@ pub fn init() -> tauri::plugin::TauriPlugin { } } - app.manage(SharedState::default()); + let api_key = { + #[cfg(not(debug_assertions))] + { + Some(env!("AM_API_KEY").to_string()) + } + + #[cfg(debug_assertions)] + { + option_env!("AM_API_KEY").map(|s| s.to_string()) + } + }; + + 
app.manage(SharedState::new(tokio::sync::Mutex::new(State { + am_api_key: api_key, + ..Default::default() + }))); + Ok(()) }) .build() @@ -121,8 +136,8 @@ mod test { use futures_util::StreamExt; let app = create_app(tauri::test::mock_builder()); - app.start_server().await.unwrap(); - let api_base = app.api_base().await.unwrap(); + app.start_server(None).await.unwrap(); + let api_base = app.get_api_base(None).await.unwrap().unwrap(); let listen_client = owhisper_client::ListenClient::builder() .api_base(api_base) @@ -149,6 +164,6 @@ mod test { println!("{:?}", chunk); } - app.stop_server().await.unwrap(); + app.stop_server(None).await.unwrap(); } } diff --git a/plugins/local-stt/src/server/external.rs b/plugins/local-stt/src/server/external.rs new file mode 100644 index 000000000..0dc6b13df --- /dev/null +++ b/plugins/local-stt/src/server/external.rs @@ -0,0 +1,21 @@ +#[derive(Clone)] +pub struct ServerHandle { + pub api_base: String, + pub shutdown: tokio::sync::watch::Sender<()>, + client: Option, +} + +pub async fn run_server( + cmd: tauri_plugin_shell::process::Command, +) -> Result { + let (_rx, _child) = cmd.args(["serve", "--port", "6942"]).spawn()?; + + let api_base = "http://localhost:6942"; + let client = hypr_am::AmClient::new(api_base); + + Ok(ServerHandle { + api_base: api_base.to_string(), + client: Some(client), + shutdown: tokio::sync::watch::channel(()).0, + }) +} diff --git a/plugins/local-stt/src/server.rs b/plugins/local-stt/src/server/internal.rs similarity index 97% rename from plugins/local-stt/src/server.rs rename to plugins/local-stt/src/server/internal.rs index 2e51a0bc7..4bdedb544 100644 --- a/plugins/local-stt/src/server.rs +++ b/plugins/local-stt/src/server/internal.rs @@ -47,7 +47,7 @@ impl ServerState { #[derive(Clone)] pub struct ServerHandle { - pub addr: SocketAddr, + pub api_base: String, pub shutdown: tokio::sync::watch::Sender<()>, } @@ -58,11 +58,12 @@ pub async fn run_server(state: ServerState) -> Result