diff --git a/apps/desktop/src/components/editor-area/note-header/listen-button.tsx b/apps/desktop/src/components/editor-area/note-header/listen-button.tsx index bd6856ec0..fa0096210 100644 --- a/apps/desktop/src/components/editor-area/note-header/listen-button.tsx +++ b/apps/desktop/src/components/editor-area/note-header/listen-button.tsx @@ -74,10 +74,11 @@ export default function ListenButton({ sessionId, isCompact = false }: { session refetchInterval: 1500, enabled: ongoingSessionStatus !== "running_active", queryFn: async () => { - const currentModel = await localSttCommands.getCurrentModel(); + const currentModel = await localSttCommands.getLocalModel(); const isDownloaded = await localSttCommands.isModelDownloaded(currentModel); const servers = await localSttCommands.getServers(); - const isServerAvailable = (servers.external === "ready") || (servers.internal === "ready"); + const isServerAvailable = (servers.external === "ready") || (servers.internal === "ready") + || (servers.custom === "ready"); return isDownloaded && isServerAvailable; }, }); diff --git a/apps/desktop/src/components/settings/components/ai/stt-view-local.tsx b/apps/desktop/src/components/settings/components/ai/stt-view-local.tsx index 811935c79..c03d50e12 100644 --- a/apps/desktop/src/components/settings/components/ai/stt-view-local.tsx +++ b/apps/desktop/src/components/settings/components/ai/stt-view-local.tsx @@ -28,6 +28,8 @@ const REFETCH_INTERVALS = { interface STTViewProps extends SharedSTTProps { isWerModalOpen: boolean; setIsWerModalOpen: (open: boolean) => void; + provider: "Local" | "Custom"; + setProviderToLocal: () => void; } interface ModelSectionProps { @@ -37,6 +39,8 @@ interface ModelSectionProps { setSelectedSTTModel: (model: string) => void; downloadingModels: Set; handleModelDownload: (model: string) => void; + provider: "Local" | "Custom"; + setProviderToLocal: () => void; } export function STTViewLocal({ @@ -46,22 +50,20 @@ export function STTViewLocal({ setSttModels, downloadingModels, handleModelDownload, + provider, + setProviderToLocal, }: STTViewProps) { const amAvailable = useMemo(() => platform() === "macos" && arch() === "aarch64", []); const servers = useQuery({ queryKey: ["local-stt-servers"], - queryFn: async () => { - const servers = await localSttCommands.getServers(); - console.log(servers); - return servers; - }, + queryFn: async () => localSttCommands.getServers(), refetchInterval: REFETCH_INTERVALS.servers, }); const currentSTTModel = useQuery({ queryKey: ["current-stt-model"], - queryFn: () => localSttCommands.getCurrentModel(), + queryFn: () => localSttCommands.getLocalModel(), }); const sttModelDownloadStatus = useQuery({ @@ -133,6 +135,8 @@ export function STTViewLocal({ setSelectedSTTModel={setSelectedSTTModel} downloadingModels={downloadingModels} handleModelDownload={handleModelDownload} + provider={provider} + setProviderToLocal={setProviderToLocal} /> {/* Divider - only show if pro models available */} @@ -147,6 +151,8 @@ export function STTViewLocal({ setSelectedSTTModel={setSelectedSTTModel} downloadingModels={downloadingModels} handleModelDownload={handleModelDownload} + provider={provider} + setProviderToLocal={setProviderToLocal} /> )} @@ -164,6 +170,8 @@ function BasicModelsSection({ setSelectedSTTModel, downloadingModels, handleModelDownload, + provider, + setProviderToLocal, }: ModelSectionProps) { const handleShowFileLocation = async () => { const path = await localSttCommands.modelsDir(); @@ -191,6 +199,8 @@ function BasicModelsSection({ downloadingModels={downloadingModels} handleModelDownload={handleModelDownload} handleShowFileLocation={handleShowFileLocation} + provider={provider} + setProviderToLocal={setProviderToLocal} /> ))} @@ -207,6 +217,8 @@ function ProModelsSection({ setSelectedSTTModel, downloadingModels, handleModelDownload, + provider, + setProviderToLocal, }: Omit) { const { getLicense } = useLicense(); @@ -258,6 +270,8 @@ function ProModelsSection({ downloadingModels={downloadingModels} handleModelDownload={handleModelDownload} handleShowFileLocation={handleShowFileLocation} + provider={provider} + setProviderToLocal={setProviderToLocal} /> ))} @@ -328,6 +342,8 @@ function ModelEntry({ downloadingModels, handleModelDownload, handleShowFileLocation, + provider, + setProviderToLocal, disabled = false, }: { model: STTModel; @@ -336,16 +352,20 @@ function ModelEntry({ downloadingModels: Set; handleModelDownload: (model: string) => void; handleShowFileLocation: () => void; + provider: "Local" | "Custom"; + setProviderToLocal: () => void; disabled?: boolean; }) { - const isSelected = selectedSTTModel === model.key && model.downloaded; + // only highlight if provider is Local and this is the selected model + const isSelected = provider === "Local" && selectedSTTModel === model.key && model.downloaded; const isSelectable = model.downloaded && !disabled; const isDownloading = downloadingModels.has(model.key); const handleClick = () => { if (isSelectable) { setSelectedSTTModel(model.key as SupportedSttModel); - localSttCommands.setCurrentModel(model.key as SupportedSttModel); + localSttCommands.setLocalModel(model.key as SupportedSttModel); + setProviderToLocal(); localSttCommands.stopServer(null); localSttCommands.startServer(null); } diff --git a/apps/desktop/src/components/settings/components/ai/stt-view-remote.tsx b/apps/desktop/src/components/settings/components/ai/stt-view-remote.tsx index f309e0b10..697ebc73b 100644 --- a/apps/desktop/src/components/settings/components/ai/stt-view-remote.tsx +++ b/apps/desktop/src/components/settings/components/ai/stt-view-remote.tsx @@ -1,40 +1,224 @@ -import { CloudIcon, ExternalLinkIcon } from "lucide-react"; +import { Trans } from "@lingui/react/macro"; +import { useMutation, useQuery } from "@tanstack/react-query"; +import { openUrl } from "@tauri-apps/plugin-opener"; +import { useEffect } from "react"; +import { useForm } from "react-hook-form"; + +import { commands as localSttCommands } from "@hypr/plugin-local-stt"; +import { Form, FormControl, FormDescription, FormField, FormItem, FormMessage } from "@hypr/ui/components/ui/form"; +import { Input } from "@hypr/ui/components/ui/input"; +import { cn } from "@hypr/ui/lib/utils"; + +export function STTViewRemote({ + provider, + setProviderToCustom, +}: { + provider: "Local" | "Custom"; + setProviderToCustom: () => void; +}) { + const apiBaseQuery = useQuery({ + queryKey: ["custom-stt-base-url"], + queryFn: () => localSttCommands.getCustomBaseUrl(), + }); + + const apiKeyQuery = useQuery({ + queryKey: ["custom-stt-api-key"], + queryFn: () => localSttCommands.getCustomApiKey(), + }); + + const modelQuery = useQuery({ + queryKey: ["custom-stt-model"], + queryFn: () => localSttCommands.getCustomModel(), + }); + + const setApiBaseMutation = useMutation({ + mutationFn: (apiBase: string) => localSttCommands.setCustomBaseUrl(apiBase), + onSuccess: () => apiBaseQuery.refetch(), + }); + + const setApiKeyMutation = useMutation({ + mutationFn: (apiKey: string) => localSttCommands.setCustomApiKey(apiKey), + onSuccess: () => apiKeyQuery.refetch(), + }); + + const setModelMutation = useMutation({ + mutationFn: (model: string) => localSttCommands.setCustomModel(model), + onSuccess: () => modelQuery.refetch(), + }); + + const form = useForm({ + defaultValues: { + api_base: "", + api_key: "", + model: "", + }, + }); + + useEffect(() => { + form.reset({ + api_base: apiBaseQuery.data || "", + api_key: apiKeyQuery.data || "", + model: modelQuery.data || "", + }); + }, [apiBaseQuery.data, apiKeyQuery.data, modelQuery.data, form]); + + useEffect(() => { + const subscription = form.watch((values, { name }) => { + if (name === "api_base") { + setApiBaseMutation.mutate(values.api_base || ""); + } + if (name === "api_key") { + setApiKeyMutation.mutate(values.api_key || ""); + } + if (name === "model") { + setModelMutation.mutate(values.model || ""); + } + }); + return () => subscription.unsubscribe(); + }, [form.watch, setApiBaseMutation, setApiKeyMutation, setModelMutation]); + + const isSelected = provider === "Custom"; -export function STTViewRemote() { return ( -
-
- -

- Custom Transcription -

-

- Coming Soon -

-
+
+
+ {/* Custom STT Endpoint Box */} +
{ + setProviderToCustom(); + }} + > +
+
+
+
+ + Custom Speech-to-Text endpoint + + + Preview + +
+

+ + Connect to{" "} + openUrl("https://deepgram.com")} + > + Deepgram + {" "} + directly, or use{" "} + openUrl("https://docs.hyprnote.com/owhisper/what-is-this")} + > + OWhisper + {" "} + for other provider support. + +

+
+
+
+ +
+
+
+ + {/* Base URL Section */} +
+

+ Base URL +

+ ( + + + Enter the base URL for your custom STT endpoint + + + e.stopPropagation()} + onFocus={() => setProviderToCustom()} + /> + + + + )} + /> +
+ + {/* API Key Section */} +
+

+ API Key +

+ ( + + + Your authentication key for accessing the STT service + + + e.stopPropagation()} + onFocus={() => setProviderToCustom()} + /> + + + + )} + /> +
-
-

- Powered by{" "} - - Owhisper - - -

-

- Interested in team features?{" "} - - Contact help@hyprnote.com - -

+ {/* Model Section */} +
+

+ Model +

+ ( + + + Enter the model name required by your STT endpoint + + + e.stopPropagation()} + onFocus={() => setProviderToCustom()} + /> + + + + )} + /> +
+ + +
+
+
); diff --git a/apps/desktop/src/components/settings/views/ai-stt.tsx b/apps/desktop/src/components/settings/views/ai-stt.tsx index a61d8a129..483d87313 100644 --- a/apps/desktop/src/components/settings/views/ai-stt.tsx +++ b/apps/desktop/src/components/settings/views/ai-stt.tsx @@ -1,6 +1,6 @@ import { Trans } from "@lingui/react/macro"; -import { useQueryClient } from "@tanstack/react-query"; -import { useState } from "react"; +import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; +import { useEffect, useState } from "react"; import { commands as localSttCommands } from "@hypr/plugin-local-stt"; import { Tabs, TabsContent, TabsList, TabsTrigger } from "@hypr/ui/components/ui/tabs"; @@ -13,6 +13,39 @@ export default function SttAI() { const queryClient = useQueryClient(); const [activeTab, setActiveTab] = useState<"default" | "custom">("default"); + const providerQuery = useQuery({ + queryKey: ["stt-provider"], + queryFn: () => localSttCommands.getProvider(), + }); + + const setProviderMutation = useMutation({ + mutationFn: (provider: "Local" | "Custom") => { + if (provider === "Custom") { + localSttCommands.stopServer(null); + } + return localSttCommands.setProvider(provider); + }, + onSuccess: () => { + providerQuery.refetch(); + }, + onError: (error) => { + console.error("Failed to set provider:", error); + }, + }); + + const provider = providerQuery.data ?? "Local"; + + useEffect(() => { + if (provider === "Custom") { + setActiveTab("custom"); + } else { + setActiveTab("default"); + } + }, [provider]); + + const setProviderToLocal = () => setProviderMutation.mutate("Local"); + const setProviderToCustom = () => setProviderMutation.mutate("Custom"); + const [isWerModalOpen, setIsWerModalOpen] = useState(false); const [selectedSTTModel, setSelectedSTTModel] = useState("QuantizedTiny"); const [sttModels, setSttModels] = useState(initialSttModels); @@ -36,11 +69,17 @@ export default function SttAI() { }); setSelectedSTTModel(modelKey); - localSttCommands.setCurrentModel(modelKey as any); + localSttCommands.setLocalModel(modelKey as any); + setProviderToLocal(); }, queryClient); }; - const sttProps: SharedSTTProps & { isWerModalOpen: boolean; setIsWerModalOpen: (open: boolean) => void } = { + const sttProps: SharedSTTProps & { + isWerModalOpen: boolean; + setIsWerModalOpen: (open: boolean) => void; + provider: "Local" | "Custom"; + setProviderToLocal: () => void; + } = { selectedSTTModel, setSelectedSTTModel, sttModels, @@ -49,6 +88,8 @@ export default function SttAI() { handleModelDownload, isWerModalOpen, setIsWerModalOpen, + provider, + setProviderToLocal, }; return ( @@ -70,7 +111,7 @@ export default function SttAI() { - +
diff --git a/apps/desktop/src/components/toast/model-download.tsx b/apps/desktop/src/components/toast/model-download.tsx index c947199f1..e1d9aab7b 100644 --- a/apps/desktop/src/components/toast/model-download.tsx +++ b/apps/desktop/src/components/toast/model-download.tsx @@ -15,7 +15,7 @@ export default function ModelDownloadNotification() { }); const currentSttModel = useQuery({ queryKey: ["current-stt-model"], - queryFn: () => localSttCommands.getCurrentModel(), + queryFn: () => localSttCommands.getLocalModel(), }); const currentLlmModel = useQuery({ diff --git a/apps/desktop/src/components/toast/model-select.tsx b/apps/desktop/src/components/toast/model-select.tsx index c757c4e26..5faa91b5f 100644 --- a/apps/desktop/src/components/toast/model-select.tsx +++ b/apps/desktop/src/components/toast/model-select.tsx @@ -6,7 +6,7 @@ import { Button } from "@hypr/ui/components/ui/button"; import { sonnerToast, toast } from "@hypr/ui/components/ui/toast"; export async function showModelSelectToast(language: string) { - const currentModel = await localSttCommands.getCurrentModel(); + const currentModel = await localSttCommands.getLocalModel(); const englishModels: SupportedSttModel[] = ["QuantizedTinyEn", "QuantizedBaseEn", "QuantizedSmallEn"]; if (language === "en" || !englishModels.includes(currentModel)) { diff --git a/apps/desktop/src/components/toast/shared.tsx b/apps/desktop/src/components/toast/shared.tsx index 9057d6da1..528b43f17 100644 --- a/apps/desktop/src/components/toast/shared.tsx +++ b/apps/desktop/src/components/toast/shared.tsx @@ -80,7 +80,7 @@ export function showSttModelDownloadToast( channel={sttChannel} onComplete={() => { sonnerToast.dismiss(id); - localSttCommands.setCurrentModel(model); + localSttCommands.setLocalModel(model); localSttCommands.startServer(null); if (onComplete) { onComplete(); diff --git a/apps/desktop/src/components/welcome-modal/download-progress-view.tsx b/apps/desktop/src/components/welcome-modal/download-progress-view.tsx index a08666f7e..e75a04e5d 100644 --- a/apps/desktop/src/components/welcome-modal/download-progress-view.tsx +++ b/apps/desktop/src/components/welcome-modal/download-progress-view.tsx @@ -172,7 +172,7 @@ export const DownloadProgressView = ({ const handleSttCompletion = async () => { if (sttDownload.completed) { try { - await localSttCommands.setCurrentModel(selectedSttModel); + await localSttCommands.setLocalModel(selectedSttModel); await localSttCommands.startServer(null); } catch (error) { console.error("Error setting up STT:", error); diff --git a/apps/desktop/src/components/welcome-modal/index.tsx b/apps/desktop/src/components/welcome-modal/index.tsx index d0997f4ae..fc4ec836f 100644 --- a/apps/desktop/src/components/welcome-modal/index.tsx +++ b/apps/desktop/src/components/welcome-modal/index.tsx @@ -76,7 +76,7 @@ export function WelcomeModal({ isOpen, onClose }: WelcomeModalProps) { /* const selectSTTModel = useMutation({ - mutationFn: (model: WhisperModel) => localSttCommands.setCurrentModel(model), + mutationFn: (model: WhisperModel) => localSttCommands.setLocalModel(model), }); */ diff --git a/apps/desktop/src/locales/en/messages.po b/apps/desktop/src/locales/en/messages.po index bdfdcafac..3b36e2ae4 100644 --- a/apps/desktop/src/locales/en/messages.po +++ b/apps/desktop/src/locales/en/messages.po @@ -267,10 +267,10 @@ msgstr "(Optional)" #. placeholder {0}: disabled ? "Wait..." : "Play again" #. placeholder {0}: disabled ? "Wait..." : "Play video" #. placeholder {0}: disabled ? "Wait..." : isHovered ? (isCompact ? "Resume" : "Resume") : (isCompact ? "Ended" : "Ended") -#: src/components/editor-area/note-header/listen-button.tsx:139 -#: src/components/editor-area/note-header/listen-button.tsx:217 -#: src/components/editor-area/note-header/listen-button.tsx:241 -#: src/components/editor-area/note-header/listen-button.tsx:261 +#: src/components/editor-area/note-header/listen-button.tsx:140 +#: src/components/editor-area/note-header/listen-button.tsx:218 +#: src/components/editor-area/note-header/listen-button.tsx:242 +#: src/components/editor-area/note-header/listen-button.tsx:262 #: src/components/settings/views/templates.tsx:252 msgid "{0}" msgstr "{0}" @@ -407,6 +407,7 @@ msgstr "API Base URL" #: src/components/settings/components/ai/llm-custom-view.tsx:409 #: src/components/settings/components/ai/llm-custom-view.tsx:514 #: src/components/settings/components/ai/llm-custom-view.tsx:625 +#: src/components/settings/components/ai/stt-view-remote.tsx:166 #: src/components/settings/views/integrations.tsx:203 #: src/components/welcome-modal/custom-endpoint-view.tsx:294 #: src/components/welcome-modal/custom-endpoint-view.tsx:361 @@ -452,6 +453,7 @@ msgstr "Back" msgid "Base Folder" msgstr "Base Folder" +#: src/components/settings/components/ai/stt-view-remote.tsx:138 #: src/components/settings/views/integrations.tsx:179 msgid "Base URL" msgstr "Base URL" @@ -565,6 +567,10 @@ msgstr "Complete the configuration to continue" #~ msgid "Connect" #~ msgstr "Connect" +#: src/components/settings/components/ai/stt-view-remote.tsx:109 +msgid "Connect to <0>Deepgram directly, or use <1>OWhisper for other provider support." +msgstr "Connect to <0>Deepgram directly, or use <1>OWhisper for other provider support." + #: src/components/settings/components/ai/llm-custom-view.tsx:583 msgid "Connect to a self-hosted or third-party LLM endpoint (OpenAI API compatible)" msgstr "Connect to a self-hosted or third-party LLM endpoint (OpenAI API compatible)" @@ -573,6 +579,10 @@ msgstr "Connect to a self-hosted or third-party LLM endpoint (OpenAI API compati #~ msgid "Connect to a self-hosted or third-party LLM endpoint (OpenAI API compatible)." #~ msgstr "Connect to a self-hosted or third-party LLM endpoint (OpenAI API compatible)." +#: src/components/settings/components/ai/stt-view-remote.tsx:111 +#~ msgid "Connect to a self-hosted or third-party STT endpoint (Deepgram compatible)" +#~ msgstr "Connect to a self-hosted or third-party STT endpoint (Deepgram compatible)" + #: src/components/settings/views/integrations.tsx:127 msgid "Connect with external tools and services to enhance your workflow" msgstr "Connect with external tools and services to enhance your workflow" @@ -660,7 +670,7 @@ msgstr "Create your first template to get started" #~ msgstr "Current Plan" #: src/components/settings/views/ai-llm.tsx:671 -#: src/components/settings/views/ai-stt.tsx:66 +#: src/components/settings/views/ai-stt.tsx:107 msgid "Custom" msgstr "Custom" @@ -672,12 +682,16 @@ msgstr "Custom" #~ msgid "Custom Endpoints" #~ msgstr "Custom Endpoints" +#: src/components/settings/components/ai/stt-view-remote.tsx:102 +msgid "Custom Speech-to-Text endpoint" +msgstr "Custom Speech-to-Text endpoint" + #: src/components/settings/views/general.tsx:422 msgid "Custom Vocabulary" msgstr "Custom Vocabulary" #: src/components/settings/views/ai-llm.tsx:668 -#: src/components/settings/views/ai-stt.tsx:63 +#: src/components/settings/views/ai-stt.tsx:104 msgid "Default" msgstr "Default" @@ -787,10 +801,18 @@ msgstr "Enter the base URL for your custom LLM endpoint" #~ msgid "Enter the base URL for your custom LLM endpoint (e.g., http://localhost:8080/v1)" #~ msgstr "Enter the base URL for your custom LLM endpoint (e.g., http://localhost:8080/v1)" +#: src/components/settings/components/ai/stt-view-remote.tsx:146 +msgid "Enter the base URL for your custom STT endpoint" +msgstr "Enter the base URL for your custom STT endpoint" + #: src/components/settings/views/ai.tsx:498 #~ msgid "Enter the exact model name required by your endpoint (if applicable)." #~ msgstr "Enter the exact model name required by your endpoint (if applicable)." +#: src/components/settings/components/ai/stt-view-remote.tsx:202 +msgid "Enter the model name required by your STT endpoint" +msgstr "Enter the model name required by your STT endpoint" + #: src/routes/app.settings.tsx:72 #~ msgid "Extensions" #~ msgstr "Extensions" @@ -1055,6 +1077,7 @@ msgstr "Microphone Access" #: src/components/settings/components/ai/llm-custom-view.tsx:334 #: src/components/settings/components/ai/llm-custom-view.tsx:429 #: src/components/settings/components/ai/llm-custom-view.tsx:534 +#: src/components/settings/components/ai/stt-view-remote.tsx:194 #: src/components/welcome-modal/custom-endpoint-view.tsx:315 #: src/components/welcome-modal/custom-endpoint-view.tsx:382 #: src/components/welcome-modal/custom-endpoint-view.tsx:459 @@ -1139,7 +1162,7 @@ msgstr "No recent notes with this organization" #~ msgid "No Template" #~ msgstr "No Template" -#: src/components/editor-area/note-header/listen-button.tsx:516 +#: src/components/editor-area/note-header/listen-button.tsx:517 msgid "No Template (Default)" msgstr "No Template (Default)" @@ -1227,7 +1250,7 @@ msgstr "Others" msgid "Owner" msgstr "Owner" -#: src/components/editor-area/note-header/listen-button.tsx:362 +#: src/components/editor-area/note-header/listen-button.tsx:363 msgid "Pause" msgstr "Pause" @@ -1327,7 +1350,7 @@ msgstr "Role" msgid "Save audio recording locally alongside the transcript." msgstr "Save audio recording locally alongside the transcript." -#: src/components/editor-area/note-header/listen-button.tsx:486 +#: src/components/editor-area/note-header/listen-button.tsx:487 msgid "Save current recording" msgstr "Save current recording" @@ -1476,11 +1499,11 @@ msgstr "Start automatically at login" #~ msgid "Start Monthly Plan" #~ msgstr "Start Monthly Plan" -#: src/components/editor-area/note-header/listen-button.tsx:188 +#: src/components/editor-area/note-header/listen-button.tsx:189 msgid "Start recording" msgstr "Start recording" -#: src/components/editor-area/note-header/listen-button.tsx:463 +#: src/components/editor-area/note-header/listen-button.tsx:464 msgid "Stop" msgstr "Stop" @@ -1529,7 +1552,7 @@ msgstr "Team management features are currently under development and will be ava msgid "Teamspace" msgstr "Teamspace" -#: src/components/editor-area/note-header/listen-button.tsx:507 +#: src/components/editor-area/note-header/listen-button.tsx:508 msgid "Template" msgstr "Template" @@ -1734,6 +1757,10 @@ msgstr "Where Conversations Stay Yours" msgid "Your API key for Obsidian local-rest-api plugin." msgstr "Your API key for Obsidian local-rest-api plugin." +#: src/components/settings/components/ai/stt-view-remote.tsx:174 +msgid "Your authentication key for accessing the STT service" +msgstr "Your authentication key for accessing the STT service" + #: src/components/settings/views/profile.tsx:213 msgid "Your LinkedIn username (the part after linkedin.com/in/)" msgstr "Your LinkedIn username (the part after linkedin.com/in/)" diff --git a/apps/desktop/src/locales/ko/messages.po b/apps/desktop/src/locales/ko/messages.po index 7708be4ed..2aa7427d5 100644 --- a/apps/desktop/src/locales/ko/messages.po +++ b/apps/desktop/src/locales/ko/messages.po @@ -267,10 +267,10 @@ msgstr "" #. placeholder {0}: disabled ? "Wait..." : "Play again" #. placeholder {0}: disabled ? "Wait..." : "Play video" #. placeholder {0}: disabled ? "Wait..." : isHovered ? (isCompact ? "Resume" : "Resume") : (isCompact ? "Ended" : "Ended") -#: src/components/editor-area/note-header/listen-button.tsx:139 -#: src/components/editor-area/note-header/listen-button.tsx:217 -#: src/components/editor-area/note-header/listen-button.tsx:241 -#: src/components/editor-area/note-header/listen-button.tsx:261 +#: src/components/editor-area/note-header/listen-button.tsx:140 +#: src/components/editor-area/note-header/listen-button.tsx:218 +#: src/components/editor-area/note-header/listen-button.tsx:242 +#: src/components/editor-area/note-header/listen-button.tsx:262 #: src/components/settings/views/templates.tsx:252 msgid "{0}" msgstr "" @@ -407,6 +407,7 @@ msgstr "" #: src/components/settings/components/ai/llm-custom-view.tsx:409 #: src/components/settings/components/ai/llm-custom-view.tsx:514 #: src/components/settings/components/ai/llm-custom-view.tsx:625 +#: src/components/settings/components/ai/stt-view-remote.tsx:166 #: src/components/settings/views/integrations.tsx:203 #: src/components/welcome-modal/custom-endpoint-view.tsx:294 #: src/components/welcome-modal/custom-endpoint-view.tsx:361 @@ -452,6 +453,7 @@ msgstr "" msgid "Base Folder" msgstr "" +#: src/components/settings/components/ai/stt-view-remote.tsx:138 #: src/components/settings/views/integrations.tsx:179 msgid "Base URL" msgstr "" @@ -565,6 +567,10 @@ msgstr "" #~ msgid "Connect" #~ msgstr "" +#: src/components/settings/components/ai/stt-view-remote.tsx:109 +msgid "Connect to <0>Deepgram directly, or use <1>OWhisper for other provider support." +msgstr "" + #: src/components/settings/components/ai/llm-custom-view.tsx:583 msgid "Connect to a self-hosted or third-party LLM endpoint (OpenAI API compatible)" msgstr "" @@ -573,6 +579,10 @@ msgstr "" #~ msgid "Connect to a self-hosted or third-party LLM endpoint (OpenAI API compatible)." #~ msgstr "" +#: src/components/settings/components/ai/stt-view-remote.tsx:111 +#~ msgid "Connect to a self-hosted or third-party STT endpoint (Deepgram compatible)" +#~ msgstr "" + #: src/components/settings/views/integrations.tsx:127 msgid "Connect with external tools and services to enhance your workflow" msgstr "" @@ -660,7 +670,7 @@ msgstr "" #~ msgstr "" #: src/components/settings/views/ai-llm.tsx:671 -#: src/components/settings/views/ai-stt.tsx:66 +#: src/components/settings/views/ai-stt.tsx:107 msgid "Custom" msgstr "" @@ -672,12 +682,20 @@ msgstr "" #~ msgid "Custom Endpoints" #~ msgstr "" +#: src/components/settings/components/ai/stt-view-remote.tsx:102 +msgid "Custom Speech-to-Text endpoint" +msgstr "" + +#: src/components/settings/components/ai/stt-view-remote.tsx:108 +#~ msgid "Custom STT Endpoint" +#~ msgstr "" + #: src/components/settings/views/general.tsx:422 msgid "Custom Vocabulary" msgstr "" #: src/components/settings/views/ai-llm.tsx:668 -#: src/components/settings/views/ai-stt.tsx:63 +#: src/components/settings/views/ai-stt.tsx:104 msgid "Default" msgstr "" @@ -787,10 +805,18 @@ msgstr "" #~ msgid "Enter the base URL for your custom LLM endpoint (e.g., http://localhost:8080/v1)" #~ msgstr "" +#: src/components/settings/components/ai/stt-view-remote.tsx:146 +msgid "Enter the base URL for your custom STT endpoint" +msgstr "" + #: src/components/settings/views/ai.tsx:498 #~ msgid "Enter the exact model name required by your endpoint (if applicable)." #~ msgstr "" +#: src/components/settings/components/ai/stt-view-remote.tsx:202 +msgid "Enter the model name required by your STT endpoint" +msgstr "" + #: src/routes/app.settings.tsx:72 #~ msgid "Extensions" #~ msgstr "" @@ -1055,6 +1081,7 @@ msgstr "" #: src/components/settings/components/ai/llm-custom-view.tsx:334 #: src/components/settings/components/ai/llm-custom-view.tsx:429 #: src/components/settings/components/ai/llm-custom-view.tsx:534 +#: src/components/settings/components/ai/stt-view-remote.tsx:194 #: src/components/welcome-modal/custom-endpoint-view.tsx:315 #: src/components/welcome-modal/custom-endpoint-view.tsx:382 #: src/components/welcome-modal/custom-endpoint-view.tsx:459 @@ -1139,7 +1166,7 @@ msgstr "" #~ msgid "No Template" #~ msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:516 +#: src/components/editor-area/note-header/listen-button.tsx:517 msgid "No Template (Default)" msgstr "" @@ -1227,7 +1254,7 @@ msgstr "" msgid "Owner" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:362 +#: src/components/editor-area/note-header/listen-button.tsx:363 msgid "Pause" msgstr "" @@ -1327,7 +1354,7 @@ msgstr "" msgid "Save audio recording locally alongside the transcript." msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:486 +#: src/components/editor-area/note-header/listen-button.tsx:487 msgid "Save current recording" msgstr "" @@ -1476,11 +1503,11 @@ msgstr "" #~ msgid "Start Monthly Plan" #~ msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:188 +#: src/components/editor-area/note-header/listen-button.tsx:189 msgid "Start recording" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:463 +#: src/components/editor-area/note-header/listen-button.tsx:464 msgid "Stop" msgstr "" @@ -1529,7 +1556,7 @@ msgstr "" msgid "Teamspace" msgstr "" -#: src/components/editor-area/note-header/listen-button.tsx:507 +#: src/components/editor-area/note-header/listen-button.tsx:508 msgid "Template" msgstr "" @@ -1734,6 +1761,10 @@ msgstr "" msgid "Your API key for Obsidian local-rest-api plugin." msgstr "" +#: src/components/settings/components/ai/stt-view-remote.tsx:174 +msgid "Your authentication key for accessing the STT service" +msgstr "" + #: src/components/settings/views/profile.tsx:213 msgid "Your LinkedIn username (the part after linkedin.com/in/)" msgstr "" diff --git a/apps/desktop/src/routes/app.tsx b/apps/desktop/src/routes/app.tsx index 8fd606ed3..5d3a29078 100644 --- a/apps/desktop/src/routes/app.tsx +++ b/apps/desktop/src/routes/app.tsx @@ -232,7 +232,7 @@ function RestartSTT() { return watch(sttPath, (_event) => { localSttCommands.stopServer(null).then((stopped) => { if (stopped) { - localSttCommands.getCurrentModel().then((model) => { + localSttCommands.getLocalModel().then((model) => { localSttCommands.startServer(model); }); } diff --git a/owhisper/owhisper-client/src/lib.rs b/owhisper/owhisper-client/src/lib.rs index 29aa0937d..2ce4b4419 100644 --- a/owhisper/owhisper-client/src/lib.rs +++ b/owhisper/owhisper-client/src/lib.rs @@ -98,6 +98,7 @@ impl ListenClientBuilder { .append_pair("interim_results", "true") .append_pair("sample_rate", "16000") .append_pair("encoding", "linear16") + .append_pair("multichannel", "true") .append_pair("channels", &channels.to_string()) .append_pair( "redemption_time_ms", diff --git a/owhisper/owhisper-server/src/server.rs b/owhisper/owhisper-server/src/server.rs index 2ba414d08..c8148c3f0 100644 --- a/owhisper/owhisper-server/src/server.rs +++ b/owhisper/owhisper-server/src/server.rs @@ -83,6 +83,7 @@ impl Server { .route("/health", axum::routing::get(health)) .route("/models", axum::routing::get(list_models)) .route("/v1/models", axum::routing::get(list_models)) + .route("/v1/status", axum::routing::get(status)) .with_state(app_state.clone()); let app = other_router @@ -269,6 +270,10 @@ async fn health() -> &'static str { "OK" } +async fn status() -> StatusCode { + StatusCode::NO_CONTENT +} + #[derive(serde::Serialize)] struct ModelInfo { id: String, diff --git a/plugins/listener/src/fsm.rs b/plugins/listener/src/fsm.rs index e6d09f6c8..fab0d05bc 100644 --- a/plugins/listener/src/fsm.rs +++ b/plugins/listener/src/fsm.rs @@ -615,6 +615,7 @@ async fn setup_listen_client( .api_base(conn.base_url) .api_key(conn.api_key.unwrap_or_default()) .params(owhisper_interface::ListenParams { + model: conn.model, languages, redemption_time_ms: Some(if is_onboarding { 60 } else { 400 }), ..Default::default() diff --git a/plugins/listener/src/manager.rs b/plugins/listener/src/manager.rs index 4f881529b..45966d9ce 100644 --- a/plugins/listener/src/manager.rs +++ b/plugins/listener/src/manager.rs @@ -112,6 +112,13 @@ impl TranscriptManager { ws }; + // needed for deepgram + if words.is_empty() { + return Diff { + final_words: HashMap::new(), + partial_words: self.partial_words_by_channel.clone(), + }; + } if is_final { let last_final_word_end = words.last().unwrap().end; diff --git a/plugins/local-stt/build.rs b/plugins/local-stt/build.rs index cc0b86fa9..e8b148218 100644 --- a/plugins/local-stt/build.rs +++ b/plugins/local-stt/build.rs @@ -7,10 +7,18 @@ const COMMANDS: &[&str] = &[ "start_server", "stop_server", "get_servers", - "get_current_model", - "set_current_model", + "get_local_model", + "set_local_model", "list_supported_models", "list_supported_languages", + "get_custom_base_url", + "get_custom_api_key", + "set_custom_base_url", + "set_custom_api_key", + "get_provider", + "set_provider", + "get_custom_model", + "set_custom_model", ]; fn main() { diff --git a/plugins/local-stt/js/bindings.gen.ts b/plugins/local-stt/js/bindings.gen.ts index 339ffcbd3..23441c5fa 100644 --- a/plugins/local-stt/js/bindings.gen.ts +++ b/plugins/local-stt/js/bindings.gen.ts @@ -22,11 +22,11 @@ async isModelDownloading(model: SupportedSttModel) : Promise { async downloadModel(model: SupportedSttModel, channel: TAURI_CHANNEL) : Promise { return await TAURI_INVOKE("plugin:local-stt|download_model", { model, channel }); }, -async getCurrentModel() : Promise { - return await TAURI_INVOKE("plugin:local-stt|get_current_model"); +async getLocalModel() : Promise { + return await TAURI_INVOKE("plugin:local-stt|get_local_model"); }, -async setCurrentModel(model: SupportedSttModel) : Promise { - return await TAURI_INVOKE("plugin:local-stt|set_current_model", { model }); +async setLocalModel(model: SupportedSttModel) : Promise { + return await TAURI_INVOKE("plugin:local-stt|set_local_model", { model }); }, async getServers() : Promise> { return await TAURI_INVOKE("plugin:local-stt|get_servers"); @@ -42,6 +42,30 @@ async listSupportedModels() : Promise { }, async listSupportedLanguages(model: SupportedSttModel) : Promise { return await TAURI_INVOKE("plugin:local-stt|list_supported_languages", { model }); +}, +async getCustomBaseUrl() : Promise { + return await TAURI_INVOKE("plugin:local-stt|get_custom_base_url"); +}, +async getCustomApiKey() : Promise { + return await TAURI_INVOKE("plugin:local-stt|get_custom_api_key"); +}, +async setCustomBaseUrl(baseUrl: string) : Promise { + return await TAURI_INVOKE("plugin:local-stt|set_custom_base_url", { baseUrl }); +}, +async setCustomApiKey(apiKey: string) : Promise { + return await TAURI_INVOKE("plugin:local-stt|set_custom_api_key", { apiKey }); +}, +async getProvider() : Promise { + return await TAURI_INVOKE("plugin:local-stt|get_provider"); +}, +async setProvider(provider: Provider) : Promise { + return await TAURI_INVOKE("plugin:local-stt|set_provider", { provider }); +}, +async getCustomModel() : Promise { + return await TAURI_INVOKE("plugin:local-stt|get_custom_model"); +}, +async setCustomModel(model: SupportedSttModel) : Promise { + return await TAURI_INVOKE("plugin:local-stt|set_custom_model", { model }); } } @@ -58,10 +82,11 @@ async listSupportedLanguages(model: SupportedSttModel) : Promise { export type AmModel = "am-parakeet-v2" | "am-parakeet-v3" | "am-whisper-large-v3" export type GgmlBackend = { kind: string; name: string; description: string; total_memory_mb: number; free_memory_mb: number } export type Language = { iso639: string } +export type Provider = "Local" | "Custom" export type ServerHealth = "unreachable" | "loading" | "ready" -export type ServerType = "internal" | "external" +export type ServerType = "internal" | "external" | "custom" export type SttModelInfo = { key: SupportedSttModel; display_name: string; size_bytes: number } -export type SupportedSttModel = WhisperModel | AmModel +export type SupportedSttModel = WhisperModel | AmModel | string export type TAURI_CHANNEL = null export type WhisperModel = "QuantizedTiny" | "QuantizedTinyEn" | "QuantizedBase" | "QuantizedBaseEn" | "QuantizedSmall" | "QuantizedSmallEn" | "QuantizedLargeTurbo" diff --git a/plugins/local-stt/permissions/autogenerated/commands/get_current_model.toml b/plugins/local-stt/permissions/autogenerated/commands/get_current_model.toml index 7cb3b360c..ba541bea1 100644 --- a/plugins/local-stt/permissions/autogenerated/commands/get_current_model.toml +++ b/plugins/local-stt/permissions/autogenerated/commands/get_current_model.toml @@ -4,10 +4,10 @@ [[permission]] identifier = "allow-get-current-model" -description = "Enables the get_current_model command without any pre-configured scope." -commands.allow = ["get_current_model"] +description = "Enables the get_local_model command without any pre-configured scope." +commands.allow = ["get_local_model"] [[permission]] identifier = "deny-get-current-model" -description = "Denies the get_current_model command without any pre-configured scope." -commands.deny = ["get_current_model"] +description = "Denies the get_local_model command without any pre-configured scope." +commands.deny = ["get_local_model"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/get_custom_api_key.toml b/plugins/local-stt/permissions/autogenerated/commands/get_custom_api_key.toml new file mode 100644 index 000000000..db435e2ea --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/get_custom_api_key.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-get-custom-api-key" +description = "Enables the get_custom_api_key command without any pre-configured scope." +commands.allow = ["get_custom_api_key"] + +[[permission]] +identifier = "deny-get-custom-api-key" +description = "Denies the get_custom_api_key command without any pre-configured scope." +commands.deny = ["get_custom_api_key"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/get_custom_base_url.toml b/plugins/local-stt/permissions/autogenerated/commands/get_custom_base_url.toml new file mode 100644 index 000000000..7d54e3969 --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/get_custom_base_url.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-get-custom-base-url" +description = "Enables the get_custom_base_url command without any pre-configured scope." +commands.allow = ["get_custom_base_url"] + +[[permission]] +identifier = "deny-get-custom-base-url" +description = "Denies the get_custom_base_url command without any pre-configured scope." +commands.deny = ["get_custom_base_url"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/get_custom_model.toml b/plugins/local-stt/permissions/autogenerated/commands/get_custom_model.toml new file mode 100644 index 000000000..63af9efff --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/get_custom_model.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-get-custom-model" +description = "Enables the get_custom_model command without any pre-configured scope." +commands.allow = ["get_custom_model"] + +[[permission]] +identifier = "deny-get-custom-model" +description = "Denies the get_custom_model command without any pre-configured scope." +commands.deny = ["get_custom_model"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/get_local_model.toml b/plugins/local-stt/permissions/autogenerated/commands/get_local_model.toml new file mode 100644 index 000000000..ee2c1fbb9 --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/get_local_model.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-get-local-model" +description = "Enables the get_local_model command without any pre-configured scope." +commands.allow = ["get_local_model"] + +[[permission]] +identifier = "deny-get-local-model" +description = "Denies the get_local_model command without any pre-configured scope." +commands.deny = ["get_local_model"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/get_provider.toml b/plugins/local-stt/permissions/autogenerated/commands/get_provider.toml new file mode 100644 index 000000000..a342b2423 --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/get_provider.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-get-provider" +description = "Enables the get_provider command without any pre-configured scope." +commands.allow = ["get_provider"] + +[[permission]] +identifier = "deny-get-provider" +description = "Denies the get_provider command without any pre-configured scope." +commands.deny = ["get_provider"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/set_current_model.toml b/plugins/local-stt/permissions/autogenerated/commands/set_current_model.toml index 3a0093ca7..476402ad4 100644 --- a/plugins/local-stt/permissions/autogenerated/commands/set_current_model.toml +++ b/plugins/local-stt/permissions/autogenerated/commands/set_current_model.toml @@ -4,10 +4,10 @@ [[permission]] identifier = "allow-set-current-model" -description = "Enables the set_current_model command without any pre-configured scope." -commands.allow = ["set_current_model"] +description = "Enables the set_local_model command without any pre-configured scope." +commands.allow = ["set_local_model"] [[permission]] identifier = "deny-set-current-model" -description = "Denies the set_current_model command without any pre-configured scope." -commands.deny = ["set_current_model"] +description = "Denies the set_local_model command without any pre-configured scope." +commands.deny = ["set_local_model"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/set_custom_api_key.toml b/plugins/local-stt/permissions/autogenerated/commands/set_custom_api_key.toml new file mode 100644 index 000000000..431900fd1 --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/set_custom_api_key.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-set-custom-api-key" +description = "Enables the set_custom_api_key command without any pre-configured scope." +commands.allow = ["set_custom_api_key"] + +[[permission]] +identifier = "deny-set-custom-api-key" +description = "Denies the set_custom_api_key command without any pre-configured scope." +commands.deny = ["set_custom_api_key"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/set_custom_base_url.toml b/plugins/local-stt/permissions/autogenerated/commands/set_custom_base_url.toml new file mode 100644 index 000000000..55f1d3090 --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/set_custom_base_url.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-set-custom-base-url" +description = "Enables the set_custom_base_url command without any pre-configured scope." +commands.allow = ["set_custom_base_url"] + +[[permission]] +identifier = "deny-set-custom-base-url" +description = "Denies the set_custom_base_url command without any pre-configured scope." +commands.deny = ["set_custom_base_url"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/set_custom_model.toml b/plugins/local-stt/permissions/autogenerated/commands/set_custom_model.toml new file mode 100644 index 000000000..9386aa641 --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/set_custom_model.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-set-custom-model" +description = "Enables the set_custom_model command without any pre-configured scope." +commands.allow = ["set_custom_model"] + +[[permission]] +identifier = "deny-set-custom-model" +description = "Denies the set_custom_model command without any pre-configured scope." +commands.deny = ["set_custom_model"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/set_local_model.toml b/plugins/local-stt/permissions/autogenerated/commands/set_local_model.toml new file mode 100644 index 000000000..e49e14c8d --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/set_local_model.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-set-local-model" +description = "Enables the set_local_model command without any pre-configured scope." +commands.allow = ["set_local_model"] + +[[permission]] +identifier = "deny-set-local-model" +description = "Denies the set_local_model command without any pre-configured scope." +commands.deny = ["set_local_model"] diff --git a/plugins/local-stt/permissions/autogenerated/commands/set_provider.toml b/plugins/local-stt/permissions/autogenerated/commands/set_provider.toml new file mode 100644 index 000000000..df98094ab --- /dev/null +++ b/plugins/local-stt/permissions/autogenerated/commands/set_provider.toml @@ -0,0 +1,13 @@ +# Automatically generated - DO NOT EDIT! + +"$schema" = "../../schemas/schema.json" + +[[permission]] +identifier = "allow-set-provider" +description = "Enables the set_provider command without any pre-configured scope." +commands.allow = ["set_provider"] + +[[permission]] +identifier = "deny-set-provider" +description = "Denies the set_provider command without any pre-configured scope." +commands.deny = ["set_provider"] diff --git a/plugins/local-stt/permissions/autogenerated/reference.md b/plugins/local-stt/permissions/autogenerated/reference.md index b0da85400..2cd1551eb 100644 --- a/plugins/local-stt/permissions/autogenerated/reference.md +++ b/plugins/local-stt/permissions/autogenerated/reference.md @@ -15,6 +15,14 @@ Default permissions for the plugin - `allow-set-current-model` - `allow-list-supported-models` - `allow-list-supported-languages` +- `allow-get-custom-base-url` +- `allow-get-custom-api-key` +- `allow-set-custom-base-url` +- `allow-set-custom-api-key` +- `allow-get-provider` +- `allow-set-provider` +- `allow-get-custom-model` +- `allow-set-custom-model` ## Permission Table @@ -59,7 +67,7 @@ Denies the download_model command without any pre-configured scope. -Enables the get_current_model command without any pre-configured scope. +Enables the get_local_model command without any pre-configured scope. @@ -72,7 +80,85 @@ Enables the get_current_model command without any pre-configured scope. -Denies the get_current_model command without any pre-configured scope. +Denies the get_local_model command without any pre-configured scope. + + + + + + + +`local-stt:allow-get-custom-api-key` + + + + +Enables the get_custom_api_key command without any pre-configured scope. + + + + + + + +`local-stt:deny-get-custom-api-key` + + + + +Denies the get_custom_api_key command without any pre-configured scope. + + + + + + + +`local-stt:allow-get-custom-base-url` + + + + +Enables the get_custom_base_url command without any pre-configured scope. + + + + + + + +`local-stt:deny-get-custom-base-url` + + + + +Denies the get_custom_base_url command without any pre-configured scope. + + + + + + + +`local-stt:allow-get-custom-model` + + + + +Enables the get_custom_model command without any pre-configured scope. + + + + + + + +`local-stt:deny-get-custom-model` + + + + +Denies the get_custom_model command without any pre-configured scope. @@ -106,6 +192,58 @@ Denies the get_external_server_status command without any pre-configured scope. +`local-stt:allow-get-local-model` + + + + +Enables the get_local_model command without any pre-configured scope. + + + + + + + +`local-stt:deny-get-local-model` + + + + +Denies the get_local_model command without any pre-configured scope. + + + + + + + +`local-stt:allow-get-provider` + + + + +Enables the get_provider command without any pre-configured scope. + + + + + + + +`local-stt:deny-get-provider` + + + + +Denies the get_provider command without any pre-configured scope. + + + + + + + `local-stt:allow-get-servers` @@ -423,7 +561,7 @@ Denies the restart_server command without any pre-configured scope. -Enables the set_current_model command without any pre-configured scope. +Enables the set_local_model command without any pre-configured scope. @@ -436,7 +574,137 @@ Enables the set_current_model command without any pre-configured scope. -Denies the set_current_model command without any pre-configured scope. +Denies the set_local_model command without any pre-configured scope. + + + + + + + +`local-stt:allow-set-custom-api-key` + + + + +Enables the set_custom_api_key command without any pre-configured scope. + + + + + + + +`local-stt:deny-set-custom-api-key` + + + + +Denies the set_custom_api_key command without any pre-configured scope. + + + + + + + +`local-stt:allow-set-custom-base-url` + + + + +Enables the set_custom_base_url command without any pre-configured scope. + + + + + + + +`local-stt:deny-set-custom-base-url` + + + + +Denies the set_custom_base_url command without any pre-configured scope. + + + + + + + +`local-stt:allow-set-custom-model` + + + + +Enables the set_custom_model command without any pre-configured scope. + + + + + + + +`local-stt:deny-set-custom-model` + + + + +Denies the set_custom_model command without any pre-configured scope. + + + + + + + +`local-stt:allow-set-local-model` + + + + +Enables the set_local_model command without any pre-configured scope. + + + + + + + +`local-stt:deny-set-local-model` + + + + +Denies the set_local_model command without any pre-configured scope. + + + + + + + +`local-stt:allow-set-provider` + + + + +Enables the set_provider command without any pre-configured scope. + + + + + + + +`local-stt:deny-set-provider` + + + + +Denies the set_provider command without any pre-configured scope. diff --git a/plugins/local-stt/permissions/default.toml b/plugins/local-stt/permissions/default.toml index 8d078ad27..0bb62ef33 100644 --- a/plugins/local-stt/permissions/default.toml +++ b/plugins/local-stt/permissions/default.toml @@ -12,4 +12,12 @@ permissions = [ "allow-set-current-model", "allow-list-supported-models", "allow-list-supported-languages", + "allow-get-custom-base-url", + "allow-get-custom-api-key", + "allow-set-custom-base-url", + "allow-set-custom-api-key", + "allow-get-provider", + "allow-set-provider", + "allow-get-custom-model", + "allow-set-custom-model", ] diff --git a/plugins/local-stt/permissions/schemas/schema.json b/plugins/local-stt/permissions/schemas/schema.json index cc74b0708..231c0ff26 100644 --- a/plugins/local-stt/permissions/schemas/schema.json +++ b/plugins/local-stt/permissions/schemas/schema.json @@ -307,16 +307,52 @@ "markdownDescription": "Denies the download_model command without any pre-configured scope." }, { - "description": "Enables the get_current_model command without any pre-configured scope.", + "description": "Enables the get_local_model command without any pre-configured scope.", "type": "string", "const": "allow-get-current-model", - "markdownDescription": "Enables the get_current_model command without any pre-configured scope." + "markdownDescription": "Enables the get_local_model command without any pre-configured scope." }, { - "description": "Denies the get_current_model command without any pre-configured scope.", + "description": "Denies the get_local_model command without any pre-configured scope.", "type": "string", "const": "deny-get-current-model", - "markdownDescription": "Denies the get_current_model command without any pre-configured scope." + "markdownDescription": "Denies the get_local_model command without any pre-configured scope." + }, + { + "description": "Enables the get_custom_api_key command without any pre-configured scope.", + "type": "string", + "const": "allow-get-custom-api-key", + "markdownDescription": "Enables the get_custom_api_key command without any pre-configured scope." + }, + { + "description": "Denies the get_custom_api_key command without any pre-configured scope.", + "type": "string", + "const": "deny-get-custom-api-key", + "markdownDescription": "Denies the get_custom_api_key command without any pre-configured scope." + }, + { + "description": "Enables the get_custom_base_url command without any pre-configured scope.", + "type": "string", + "const": "allow-get-custom-base-url", + "markdownDescription": "Enables the get_custom_base_url command without any pre-configured scope." + }, + { + "description": "Denies the get_custom_base_url command without any pre-configured scope.", + "type": "string", + "const": "deny-get-custom-base-url", + "markdownDescription": "Denies the get_custom_base_url command without any pre-configured scope." + }, + { + "description": "Enables the get_custom_model command without any pre-configured scope.", + "type": "string", + "const": "allow-get-custom-model", + "markdownDescription": "Enables the get_custom_model command without any pre-configured scope." + }, + { + "description": "Denies the get_custom_model command without any pre-configured scope.", + "type": "string", + "const": "deny-get-custom-model", + "markdownDescription": "Denies the get_custom_model command without any pre-configured scope." }, { "description": "Enables the get_external_server_status command without any pre-configured scope.", @@ -330,6 +366,30 @@ "const": "deny-get-external-server-status", "markdownDescription": "Denies the get_external_server_status command without any pre-configured scope." }, + { + "description": "Enables the get_local_model command without any pre-configured scope.", + "type": "string", + "const": "allow-get-local-model", + "markdownDescription": "Enables the get_local_model command without any pre-configured scope." + }, + { + "description": "Denies the get_local_model command without any pre-configured scope.", + "type": "string", + "const": "deny-get-local-model", + "markdownDescription": "Denies the get_local_model command without any pre-configured scope." + }, + { + "description": "Enables the get_provider command without any pre-configured scope.", + "type": "string", + "const": "allow-get-provider", + "markdownDescription": "Enables the get_provider command without any pre-configured scope." + }, + { + "description": "Denies the get_provider command without any pre-configured scope.", + "type": "string", + "const": "deny-get-provider", + "markdownDescription": "Denies the get_provider command without any pre-configured scope." + }, { "description": "Enables the get_servers command without any pre-configured scope.", "type": "string", @@ -475,16 +535,76 @@ "markdownDescription": "Denies the restart_server command without any pre-configured scope." }, { - "description": "Enables the set_current_model command without any pre-configured scope.", + "description": "Enables the set_local_model command without any pre-configured scope.", "type": "string", "const": "allow-set-current-model", - "markdownDescription": "Enables the set_current_model command without any pre-configured scope." + "markdownDescription": "Enables the set_local_model command without any pre-configured scope." }, { - "description": "Denies the set_current_model command without any pre-configured scope.", + "description": "Denies the set_local_model command without any pre-configured scope.", "type": "string", "const": "deny-set-current-model", - "markdownDescription": "Denies the set_current_model command without any pre-configured scope." + "markdownDescription": "Denies the set_local_model command without any pre-configured scope." + }, + { + "description": "Enables the set_custom_api_key command without any pre-configured scope.", + "type": "string", + "const": "allow-set-custom-api-key", + "markdownDescription": "Enables the set_custom_api_key command without any pre-configured scope." + }, + { + "description": "Denies the set_custom_api_key command without any pre-configured scope.", + "type": "string", + "const": "deny-set-custom-api-key", + "markdownDescription": "Denies the set_custom_api_key command without any pre-configured scope." + }, + { + "description": "Enables the set_custom_base_url command without any pre-configured scope.", + "type": "string", + "const": "allow-set-custom-base-url", + "markdownDescription": "Enables the set_custom_base_url command without any pre-configured scope." + }, + { + "description": "Denies the set_custom_base_url command without any pre-configured scope.", + "type": "string", + "const": "deny-set-custom-base-url", + "markdownDescription": "Denies the set_custom_base_url command without any pre-configured scope." + }, + { + "description": "Enables the set_custom_model command without any pre-configured scope.", + "type": "string", + "const": "allow-set-custom-model", + "markdownDescription": "Enables the set_custom_model command without any pre-configured scope." + }, + { + "description": "Denies the set_custom_model command without any pre-configured scope.", + "type": "string", + "const": "deny-set-custom-model", + "markdownDescription": "Denies the set_custom_model command without any pre-configured scope." + }, + { + "description": "Enables the set_local_model command without any pre-configured scope.", + "type": "string", + "const": "allow-set-local-model", + "markdownDescription": "Enables the set_local_model command without any pre-configured scope." + }, + { + "description": "Denies the set_local_model command without any pre-configured scope.", + "type": "string", + "const": "deny-set-local-model", + "markdownDescription": "Denies the set_local_model command without any pre-configured scope." + }, + { + "description": "Enables the set_provider command without any pre-configured scope.", + "type": "string", + "const": "allow-set-provider", + "markdownDescription": "Enables the set_provider command without any pre-configured scope." + }, + { + "description": "Denies the set_provider command without any pre-configured scope.", + "type": "string", + "const": "deny-set-provider", + "markdownDescription": "Denies the set_provider command without any pre-configured scope." }, { "description": "Enables the start_server command without any pre-configured scope.", @@ -511,10 +631,10 @@ "markdownDescription": "Denies the stop_server command without any pre-configured scope." }, { - "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-servers`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-list-supported-languages`", + "description": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-servers`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-list-supported-languages`\n- `allow-get-custom-base-url`\n- `allow-get-custom-api-key`\n- `allow-set-custom-base-url`\n- `allow-set-custom-api-key`\n- `allow-get-provider`\n- `allow-set-provider`\n- `allow-get-custom-model`\n- `allow-set-custom-model`", "type": "string", "const": "default", - "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-servers`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-list-supported-languages`" + "markdownDescription": "Default permissions for the plugin\n#### This default permission set includes:\n\n- `allow-models-dir`\n- `allow-is-model-downloaded`\n- `allow-is-model-downloading`\n- `allow-download-model`\n- `allow-start-server`\n- `allow-stop-server`\n- `allow-get-servers`\n- `allow-get-current-model`\n- `allow-set-current-model`\n- `allow-list-supported-models`\n- `allow-list-supported-languages`\n- `allow-get-custom-base-url`\n- `allow-get-custom-api-key`\n- `allow-set-custom-base-url`\n- `allow-set-custom-api-key`\n- `allow-get-provider`\n- `allow-set-provider`\n- `allow-get-custom-model`\n- `allow-set-custom-model`" } ] } diff --git a/plugins/local-stt/src/commands.rs b/plugins/local-stt/src/commands.rs index 3ed1e2743..cabdc18b2 100644 --- a/plugins/local-stt/src/commands.rs +++ b/plugins/local-stt/src/commands.rs @@ -60,21 +60,19 @@ pub async fn download_model( #[tauri::command] #[specta::specta] -pub fn get_current_model( +pub fn get_local_model( app: tauri::AppHandle, ) -> Result { - app.get_current_model().map_err(|e| e.to_string()) + app.get_local_model().map_err(|e| e.to_string()) } #[tauri::command] #[specta::specta] -pub async fn set_current_model( +pub async fn set_local_model( app: tauri::AppHandle, model: SupportedSttModel, ) -> Result<(), String> { - app.set_current_model(model) - .await - .map_err(|e| e.to_string()) + app.set_local_model(model).await.map_err(|e| e.to_string()) } #[tauri::command] @@ -110,3 +108,69 @@ pub async fn get_servers( pub fn list_supported_languages(model: SupportedSttModel) -> Vec { model.supported_languages() } + +#[tauri::command] +#[specta::specta] +pub fn get_custom_base_url(app: tauri::AppHandle) -> Result { + app.get_custom_base_url().map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub fn get_custom_api_key( + app: tauri::AppHandle, +) -> Result, String> { + app.get_custom_api_key().map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub fn set_custom_base_url( + app: tauri::AppHandle, + base_url: String, +) -> Result<(), String> { + app.set_custom_base_url(base_url).map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub fn set_custom_api_key( + app: tauri::AppHandle, + api_key: String, +) -> Result<(), String> { + app.set_custom_api_key(api_key).map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub fn get_provider( + app: tauri::AppHandle, +) -> Result { + app.get_provider().map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub async fn set_provider( + app: tauri::AppHandle, + provider: crate::Provider, +) -> Result<(), String> { + app.set_provider(provider).await.map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub fn get_custom_model( + app: tauri::AppHandle, +) -> Result, String> { + app.get_custom_model().map_err(|e| e.to_string()) +} + +#[tauri::command] +#[specta::specta] +pub fn set_custom_model( + app: tauri::AppHandle, + model: SupportedSttModel, +) -> Result<(), String> { + app.set_custom_model(model).map_err(|e| e.to_string()) +} diff --git a/plugins/local-stt/src/ext.rs b/plugins/local-stt/src/ext.rs index 02d9e78e3..163d64f4a 100644 --- a/plugins/local-stt/src/ext.rs +++ b/plugins/local-stt/src/ext.rs @@ -12,14 +12,22 @@ use tokio_util::sync::CancellationToken; use crate::{ model::SupportedSttModel, server::{external, internal, ServerHealth, ServerType}, - Connection, + Connection, Provider, StoreKey, }; pub trait LocalSttPluginExt { - fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore; + fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore; + fn models_dir(&self) -> PathBuf; fn list_ggml_backends(&self) -> Vec; + fn get_custom_base_url(&self) -> Result; + fn set_custom_base_url(&self, base_url: impl Into) -> Result<(), crate::Error>; + fn get_custom_api_key(&self) -> Result, crate::Error>; + fn set_custom_api_key(&self, api_key: impl Into) -> Result<(), crate::Error>; + fn get_provider(&self) -> Result; + fn set_provider(&self, provider: Provider) -> impl Future>; + fn get_connection(&self) -> impl Future>; fn start_server( @@ -34,12 +42,15 @@ pub trait LocalSttPluginExt { &self, ) -> impl Future, crate::Error>>; - fn get_current_model(&self) -> Result; - fn set_current_model( + fn get_local_model(&self) -> Result; + fn set_local_model( &self, model: SupportedSttModel, ) -> impl Future>; + fn get_custom_model(&self) -> Result, crate::Error>; + fn set_custom_model(&self, model: SupportedSttModel) -> Result<(), crate::Error>; + fn download_model( &self, model: SupportedSttModel, @@ -54,7 +65,7 @@ pub trait LocalSttPluginExt { } impl> LocalSttPluginExt for T { - fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore { + fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore { self.scoped_store(crate::PLUGIN_NAME).unwrap() } @@ -66,65 +77,137 @@ impl> LocalSttPluginExt for T { hypr_whisper_local::list_ggml_backends() } - async fn get_connection(&self) -> Result { - let model = self.get_current_model()?; + fn get_custom_base_url(&self) -> Result { + let store = self.local_stt_store(); + let v = store.get(StoreKey::CustomBaseUrl)?; + Ok(v.unwrap_or_default()) + } - match model { - SupportedSttModel::Am(_) => { - let existing_api_base = { - let state = self.state::(); - let guard = state.lock().await; - guard.external_server.as_ref().map(|s| s.base_url.clone()) - }; + fn get_custom_api_key(&self) -> Result, crate::Error> { + let store = self.local_stt_store(); + let v = store.get(StoreKey::CustomApiKey)?; + Ok(v) + } - let am_key = { - let state = self.state::(); - let key = state.lock().await.am_api_key.clone(); - key.clone().ok_or(crate::Error::AmApiKeyNotSet)? - }; + fn get_provider(&self) -> Result { + let store = self.local_stt_store(); + let v = store.get(StoreKey::Provider)?; + Ok(v.unwrap_or(Provider::Local)) + } - let conn = match existing_api_base { - Some(api_base) => Connection { - base_url: api_base, - api_key: Some(am_key), - }, - None => { - let api_base = self.start_server(Some(model)).await?; - Connection { - base_url: api_base, - api_key: Some(am_key), - } - } - }; - Ok(conn) - } - SupportedSttModel::Whisper(_) => { - let existing_api_base = { - let state = self.state::(); - let guard = state.lock().await; - guard.internal_server.as_ref().map(|s| s.base_url.clone()) - }; + fn set_custom_base_url(&self, base_url: impl Into) -> Result<(), crate::Error> { + let store = self.local_stt_store(); + store.set(StoreKey::CustomBaseUrl, base_url.into())?; + Ok(()) + } - let conn = match existing_api_base { - Some(api_base) => Connection { - base_url: api_base, - api_key: None, - }, - None => { - let api_base = self.start_server(Some(model)).await?; - Connection { - base_url: api_base, - api_key: None, - } + fn set_custom_api_key(&self, api_key: impl Into) -> Result<(), crate::Error> { + let store = self.local_stt_store(); + store.set(StoreKey::CustomApiKey, api_key.into())?; + Ok(()) + } + + async fn set_provider(&self, provider: Provider) -> Result<(), crate::Error> { + let store = self.local_stt_store(); + store.set(StoreKey::Provider, &provider)?; + + if matches!(provider, Provider::Local) { + let local_model = self.get_local_model()?; + self.start_server(Some(local_model)).await?; + } + + Ok(()) + } + + async fn get_connection(&self) -> Result { + let provider = self.get_provider()?; + + match provider { + Provider::Custom => { + let model = self.get_custom_model()?; + let base_url = self.get_custom_base_url()?; + let api_key = self.get_custom_api_key()?; + Ok(Connection { + model: model.map(|m| m.to_string()), + base_url, + api_key, + }) + } + Provider::Local => { + let model = self.get_local_model()?; + + match model { + SupportedSttModel::Custom(_) => { + let base_url = self.get_custom_base_url()?; + let api_key = self.get_custom_api_key()?; + Ok(Connection { + model: None, + base_url, + api_key, + }) } - }; - Ok(conn) + SupportedSttModel::Am(_) => { + let existing_api_base = { + let state = self.state::(); + let guard = state.lock().await; + guard.external_server.as_ref().map(|s| s.base_url.clone()) + }; + + let am_key = { + let state = self.state::(); + let key = state.lock().await.am_api_key.clone(); + key.clone().ok_or(crate::Error::AmApiKeyNotSet)? + }; + + let conn = match existing_api_base { + Some(api_base) => Connection { + model: None, + base_url: api_base, + api_key: Some(am_key), + }, + None => { + let api_base = self.start_server(Some(model)).await?; + Connection { + model: None, + base_url: api_base, + api_key: Some(am_key), + } + } + }; + Ok(conn) + } + SupportedSttModel::Whisper(_) => { + let existing_api_base = { + let state = self.state::(); + let guard = state.lock().await; + guard.internal_server.as_ref().map(|s| s.base_url.clone()) + }; + + let conn = match existing_api_base { + Some(api_base) => Connection { + model: None, + base_url: api_base, + api_key: None, + }, + None => { + let api_base = self.start_server(Some(model)).await?; + Connection { + model: None, + base_url: api_base, + api_key: None, + } + } + }; + Ok(conn) + } + } } } } async fn is_model_downloaded(&self, model: &SupportedSttModel) -> Result { match model { + SupportedSttModel::Custom(_) => Ok(false), SupportedSttModel::Am(model) => Ok(model.is_downloaded(self.models_dir())?), SupportedSttModel::Whisper(model) => { let model_path = self.models_dir().join(model.file_name()); @@ -147,12 +230,21 @@ impl> LocalSttPluginExt for T { #[tracing::instrument(skip_all)] async fn start_server(&self, model: Option) -> Result { + let provider = self.get_provider()?; + + if matches!(provider, Provider::Custom) { + return self.get_custom_base_url(); + } + let model = match model { Some(m) => m, - None => self.get_current_model()?, + None => self.get_local_model()?, }; let t = match &model { + SupportedSttModel::Custom(_) => { + return Err(crate::Error::UnsupportedModelType); + } SupportedSttModel::Am(_) => ServerType::External, SupportedSttModel::Whisper(_) => ServerType::Internal, }; @@ -161,6 +253,7 @@ impl> LocalSttPluginExt for T { let data_dir = self.app_handle().path().app_data_dir().unwrap().join("stt"); match t { + ServerType::Custom => Ok("".to_string()), ServerType::Internal => { if !self.is_model_downloaded(&model).await? { return Err(crate::Error::ModelNotDownloaded); @@ -178,7 +271,7 @@ impl> LocalSttPluginExt for T { let whisper_model = match model { SupportedSttModel::Whisper(m) => m, - SupportedSttModel::Am(_) => { + _ => { return Err(crate::Error::UnsupportedModelType); } }; @@ -213,7 +306,7 @@ impl> LocalSttPluginExt for T { let am_model = match model { SupportedSttModel::Am(m) => m, - SupportedSttModel::Whisper(_) => { + _ => { return Err(crate::Error::UnsupportedModelType); } }; @@ -273,6 +366,12 @@ impl> LocalSttPluginExt for T { #[tracing::instrument(skip_all)] async fn stop_server(&self, server_type: Option) -> Result { + let provider = self.get_provider()?; + + if matches!(provider, Provider::Custom) { + return Ok(false); + } + let state = self.state::(); let mut s = state.lock().await; @@ -290,6 +389,7 @@ impl> LocalSttPluginExt for T { stopped = true; } } + Some(ServerType::Custom) => {} None => { if let Some(_) = s.external_server.take() { stopped = true; @@ -308,22 +408,43 @@ impl> LocalSttPluginExt for T { let state = self.state::(); let guard = state.lock().await; - let internal_url = if let Some(server) = &guard.internal_server { + let internal_health = if let Some(server) = &guard.internal_server { let status = server.health().await; status } else { ServerHealth::Unreachable }; - let external_url = if let Some(server) = &guard.external_server { + let external_health = if let Some(server) = &guard.external_server { server.health().await } else { ServerHealth::Unreachable }; + let custom_health = { + let provider = self.get_provider()?; + if matches!(provider, Provider::Custom) { + let base_url = self.get_custom_base_url()?; + if !base_url.is_empty() { + let client = reqwest::Client::new(); + let url = format!("{}/v1/status", base_url.trim_end_matches('/')); + + match client.get(&url).send().await { + Ok(response) if response.status().as_u16() == 204 => ServerHealth::Ready, + _ => ServerHealth::Unreachable, + } + } else { + ServerHealth::Unreachable + } + } else { + ServerHealth::Unreachable + } + }; + Ok([ - (ServerType::Internal, internal_url), - (ServerType::External, external_url), + (ServerType::Internal, internal_health), + (ServerType::External, external_health), + (ServerType::Custom, custom_health), ] .into_iter() .collect()) @@ -335,6 +456,16 @@ impl> LocalSttPluginExt for T { model: SupportedSttModel, channel: Channel, ) -> Result<(), crate::Error> { + let provider = self.get_provider()?; + + if matches!(provider, Provider::Custom) { + return Err(crate::Error::UnsupportedModelType); + } + + if let SupportedSttModel::Custom(_) = model { + return Err(crate::Error::UnsupportedModelType); + } + { let existing = { let state = self.state::(); @@ -365,6 +496,9 @@ impl> LocalSttPluginExt for T { }; match model.clone() { + SupportedSttModel::Custom(_) => { + return Err(crate::Error::UnsupportedModelType); + } SupportedSttModel::Am(m) => { let tar_path = self.models_dir().join(format!("{}.tar", m.model_dir())); let final_path = self.models_dir(); @@ -450,8 +584,13 @@ impl> LocalSttPluginExt for T { #[tracing::instrument(skip_all)] async fn is_model_downloading(&self, model: &SupportedSttModel) -> bool { - let state = self.state::(); + let provider = self.get_provider().unwrap_or(Provider::Local); + + if matches!(provider, Provider::Custom) { + return false; + } + let state = self.state::(); { let guard = state.lock().await; guard.download_task.contains_key(model) @@ -459,18 +598,38 @@ impl> LocalSttPluginExt for T { } #[tracing::instrument(skip_all)] - fn get_current_model(&self) -> Result { + fn get_local_model(&self) -> Result { let store = self.local_stt_store(); - let model = store.get(crate::StoreKey::DefaultModel)?; + let model = store.get(crate::StoreKey::LocalModel)?; Ok(model.unwrap_or(SupportedSttModel::Whisper(WhisperModel::QuantizedSmall))) } #[tracing::instrument(skip_all)] - async fn set_current_model(&self, model: SupportedSttModel) -> Result<(), crate::Error> { + async fn set_local_model(&self, model: SupportedSttModel) -> Result<(), crate::Error> { + let store = self.local_stt_store(); + store.set(crate::StoreKey::LocalModel, model.clone())?; + + let provider = self.get_provider()?; + + if matches!(provider, Provider::Local) { + self.stop_server(None).await?; + self.start_server(Some(model)).await?; + } + + Ok(()) + } + + #[tracing::instrument(skip_all)] + fn get_custom_model(&self) -> Result, crate::Error> { + let store = self.local_stt_store(); + let model = store.get(crate::StoreKey::CustomModel)?; + Ok(model) + } + + #[tracing::instrument(skip_all)] + fn set_custom_model(&self, model: SupportedSttModel) -> Result<(), crate::Error> { let store = self.local_stt_store(); - store.set(crate::StoreKey::DefaultModel, model.clone())?; - self.stop_server(None).await?; - self.start_server(Some(model)).await?; + store.set(crate::StoreKey::CustomModel, model)?; Ok(()) } } diff --git a/plugins/local-stt/src/lib.rs b/plugins/local-stt/src/lib.rs index e7bcea41d..bacff0613 100644 --- a/plugins/local-stt/src/lib.rs +++ b/plugins/local-stt/src/lib.rs @@ -39,13 +39,21 @@ fn make_specta_builder() -> tauri_specta::Builder { commands::is_model_downloaded::, commands::is_model_downloading::, commands::download_model::, - commands::get_current_model::, - commands::set_current_model::, + commands::get_local_model::, + commands::set_local_model::, commands::get_servers::, commands::start_server::, commands::stop_server::, commands::list_supported_models, commands::list_supported_languages, + commands::get_custom_base_url::, + commands::get_custom_api_key::, + commands::set_custom_base_url::, + commands::set_custom_api_key::, + commands::get_provider::, + commands::set_provider::, + commands::get_custom_model::, + commands::set_custom_model::, ]) .typ::() .error_handling(tauri_specta::ErrorHandlingMode::Throw) @@ -139,7 +147,7 @@ mod test { // cargo test test_local_stt -p tauri-plugin-local-stt -- --ignored --nocapture async fn test_local_stt() { let app = create_app(tauri::test::mock_builder()); - let model = app.get_current_model(); + let model = app.get_local_model(); println!("model: {:#?}", model); } } diff --git a/plugins/local-stt/src/model.rs b/plugins/local-stt/src/model.rs index 90ec51535..f44a2391f 100644 --- a/plugins/local-stt/src/model.rs +++ b/plugins/local-stt/src/model.rs @@ -25,6 +25,18 @@ pub struct SttModelInfo { pub enum SupportedSttModel { Whisper(WhisperModel), Am(AmModel), + // must be the last item + Custom(String), +} + +impl std::fmt::Display for SupportedSttModel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SupportedSttModel::Whisper(model) => write!(f, "whisper-{}", model), + SupportedSttModel::Am(model) => write!(f, "am-{}", model), + SupportedSttModel::Custom(model) => write!(f, "{}", model), + } + } } impl SupportedSttModel { @@ -179,6 +191,7 @@ impl SupportedSttModel { hypr_am::AmModel::ParakeetV3 => parakeet_v3_languages, hypr_am::AmModel::WhisperLargeV3 => whisper_multi_languages, }, + SupportedSttModel::Custom(_) => vec![], } } @@ -194,6 +207,11 @@ impl SupportedSttModel { display_name: model.display_name().to_string(), size_bytes: model.model_size_bytes(), }, + SupportedSttModel::Custom(_) => SttModelInfo { + key: self.clone(), + display_name: "Custom".to_string(), + size_bytes: 0, + }, } } } diff --git a/plugins/local-stt/src/server/mod.rs b/plugins/local-stt/src/server/mod.rs index e57ee6892..2383cbb37 100644 --- a/plugins/local-stt/src/server/mod.rs +++ b/plugins/local-stt/src/server/mod.rs @@ -9,6 +9,8 @@ pub enum ServerType { Internal, #[serde(rename = "external")] External, + #[serde(rename = "custom")] + Custom, } #[derive( diff --git a/plugins/local-stt/src/store.rs b/plugins/local-stt/src/store.rs index 384a5bbd4..4aea499f7 100644 --- a/plugins/local-stt/src/store.rs +++ b/plugins/local-stt/src/store.rs @@ -1,8 +1,24 @@ use tauri_plugin_store2::ScopedStoreKey; -#[derive(serde::Deserialize, specta::Type, PartialEq, Eq, Hash, strum::Display)] +#[derive( + serde::Deserialize, serde::Serialize, specta::Type, PartialEq, Eq, Hash, strum::Display, +)] pub enum StoreKey { - DefaultModel, + Provider, + #[serde(rename = "DefaultModel")] // for backward compatibility + #[strum(serialize = "DefaultModel")] + LocalModel, + CustomModel, + CustomBaseUrl, + CustomApiKey, +} + +#[derive( + serde::Deserialize, serde::Serialize, specta::Type, PartialEq, Eq, Hash, strum::Display, +)] +pub enum Provider { + Local, + Custom, } impl ScopedStoreKey for StoreKey {} diff --git a/plugins/local-stt/src/types.rs b/plugins/local-stt/src/types.rs index 2cba2e122..a58dc22df 100644 --- a/plugins/local-stt/src/types.rs +++ b/plugins/local-stt/src/types.rs @@ -1,5 +1,6 @@ #[derive(Debug)] pub struct Connection { + pub model: Option, pub base_url: String, pub api_key: Option, }