diff --git a/apps/desktop/src/components/settings/ai/llm/select.tsx b/apps/desktop/src/components/settings/ai/llm/select.tsx
index 404dbe514c..42f64f4ccf 100644
--- a/apps/desktop/src/components/settings/ai/llm/select.tsx
+++ b/apps/desktop/src/components/settings/ai/llm/select.tsx
@@ -15,7 +15,10 @@ import { useBillingAccess } from "../../../../billing";
 import { useConfigValues } from "../../../../config/use-config";
 import * as main from "../../../../store/tinybase/main";
 import { listAnthropicModels } from "../shared/list-anthropic";
-import type { ListModelsResult } from "../shared/list-common";
+import {
+  type InputModality,
+  type ListModelsResult,
+} from "../shared/list-common";
 import { listGoogleModels } from "../shared/list-google";
 import { listLMStudioModels } from "../shared/list-lmstudio";
 import { listOllamaModels } from "../shared/list-ollama";
@@ -143,12 +146,7 @@ export function SelectProviderAndModel() {
   const providerRequiresPro = providerDef?.requiresPro ?? false;
   const locked = providerRequiresPro && !billing.isPro;
 
-  const listModels = () => {
-    if (!maybeListModels || locked) {
-      return { models: [], ignored: [] };
-    }
-    return maybeListModels();
-  };
+  const listModels = !locked ? maybeListModels : undefined;
 
   return (
@@ -181,7 +179,7 @@ export function SelectProviderAndModel() {
 
 function useConfiguredMapping(): Record<
   string,
-  null | (() => Promise<ListModelsResult>)
+  undefined | (() => Promise<ListModelsResult>)
 > {
   const auth = useAuth();
   const billing = useBillingAccess();
@@ -194,25 +192,35 @@ function useConfiguredMapping(): Record<
   return Object.fromEntries(
     PROVIDERS.map((provider) => {
       if (provider.requiresPro && !billing.isPro) {
-        return [provider.id, null];
+        return [provider.id, undefined];
       }
 
       if (provider.id === "hyprnote") {
         if (!auth?.session) {
-          return [provider.id, null];
+          return [provider.id, undefined];
         }
-        return [provider.id, async () => ({ models: ["Auto"], ignored: [] })];
+        const result: ListModelsResult = {
+          models: ["Auto"],
+          ignored: [],
+          metadata: {
+            Auto: {
+              input_modalities: ["text", "image"] as InputModality[],
+            },
+          },
+        };
+
+        return [provider.id, async () => result];
       }
 
       const config = configuredProviders[provider.id];
       if (!config || !config.base_url) {
-        return [provider.id, null];
+        return [provider.id, undefined];
       }
 
       if (provider.apiKey && !config.api_key) {
-        return [provider.id, null];
+        return [provider.id, undefined];
       }
 
       const { base_url, api_key } = config;
@@ -249,7 +257,7 @@ function useConfiguredMapping(): Record<
 
       return [provider.id, listModelsFunc];
     }),
-  ) as Record<string, null | (() => Promise<ListModelsResult>)>;
+  ) as Record<string, undefined | (() => Promise<ListModelsResult>)>;
   }, [configuredProviders, auth, billing]);
 
   return mapping;
diff --git a/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts b/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts
index 7d8dece637..2c6e2e11f6 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts
@@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";
 
 import {
   DEFAULT_RESULT,
+  extractMetadataMap,
   fetchJson,
+  type InputModality,
   type ListModelsResult,
   type ModelIgnoreReason,
   partition,
@@ -39,8 +41,8 @@ export async function listAnthropicModels(
       "anthropic-dangerous-direct-browser-access": "true",
     }),
     Effect.andThen((json) => Schema.decodeUnknown(AnthropicModelSchema)(json)),
-    Effect.map(({ data }) =>
-      partition(
+    Effect.map(({ data }) => ({
+      ...partition(
         data,
         (model) => {
           const reasons: ModelIgnoreReason[] = [];
@@ -51,9 +53,18 @@ export async function listAnthropicModels(
         },
         (model) => model.id,
       ),
-    ),
+      metadata: extractMetadataMap(
+        data,
+        (model) => model.id,
+        (model) => ({ input_modalities: getInputModalities(model.id) }),
+      ),
+    })),
     Effect.timeout(REQUEST_TIMEOUT),
     Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
     Effect.runPromise,
   );
 }
+
+const getInputModalities = (_modelId: string): InputModality[] => {
+  return ["text", "image"];
+};
diff --git a/apps/desktop/src/components/settings/ai/shared/list-common.ts b/apps/desktop/src/components/settings/ai/shared/list-common.ts
index 6ebd51b5cf..722792c916 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-common.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-common.ts
@@ -11,12 +11,23 @@ export type ModelIgnoreReason =
 export type IgnoredModel = { id: string; reasons: ModelIgnoreReason[] };
 
+export type InputModality = "image" | "text";
+
+export type ModelMetadata = {
+  input_modalities?: InputModality[];
+};
+
 export type ListModelsResult = {
   models: string[];
   ignored: IgnoredModel[];
+  metadata: Record<string, ModelMetadata>;
 };
 
-export const DEFAULT_RESULT: ListModelsResult = { models: [], ignored: [] };
+export const DEFAULT_RESULT: ListModelsResult = {
+  models: [],
+  ignored: [],
+  metadata: {},
+};
 
 export const REQUEST_TIMEOUT = "5 seconds";
 
 export const commonIgnoreKeywords = [
@@ -49,20 +60,52 @@ export const shouldIgnoreCommonKeywords = (id: string): boolean => {
   return commonIgnoreKeywords.some((keyword) => lowerId.includes(keyword));
 };
 
+const hasMetadata = (metadata: ModelMetadata | undefined): boolean => {
+  if (!metadata) {
+    return false;
+  }
+  if (metadata.input_modalities && metadata.input_modalities.length > 0) {
+    return true;
+  }
+  return false;
+};
+
 export const partition = <T>(
   items: readonly T[],
   shouldIgnore: (item: T) => ModelIgnoreReason[] | null,
   extract: (item: T) => string,
-): ListModelsResult => {
-  const result = { models: [] as string[], ignored: [] as IgnoredModel[] };
+): { models: string[]; ignored: IgnoredModel[] } => {
+  const models: string[] = [];
+  const ignored: IgnoredModel[] = [];
+
   for (const item of items) {
     const reasons = shouldIgnore(item);
+    const id = extract(item);
     if (!reasons || reasons.length === 0) {
-      result.models.push(extract(item));
+      models.push(id);
     } else {
-      result.ignored.push({ id: extract(item), reasons });
+      ignored.push({ id, reasons });
     }
   }
-  return result;
+
+  return { models, ignored };
+};
+
+export const extractMetadataMap = <T>(
+  items: readonly T[],
+  extract: (item: T) => string,
+  extractMetadata: (item: T) => ModelMetadata | undefined,
+): Record<string, ModelMetadata> => {
+  const metadata: Record<string, ModelMetadata> = {};
+
+  for (const item of items) {
+    const id = extract(item);
+    const meta = extractMetadata(item);
+    if (hasMetadata(meta)) {
+      metadata[id] = meta!;
+    }
+  }
+
+  return metadata;
 };
diff --git a/apps/desktop/src/components/settings/ai/shared/list-google.ts b/apps/desktop/src/components/settings/ai/shared/list-google.ts
index c2f133246b..932cb1d09a 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-google.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-google.ts
@@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";
 
 import {
   DEFAULT_RESULT,
+  extractMetadataMap,
   fetchJson,
+  type InputModality,
   type ListModelsResult,
   type ModelIgnoreReason,
   partition,
@@ -53,11 +55,20 @@ export async function listGoogleModels(
   return pipe(
     fetchJson(`${baseUrl}/models`, { "x-goog-api-key": apiKey }),
     Effect.andThen((json) => Schema.decodeUnknown(GoogleModelSchema)(json)),
-    Effect.map(({ models }) =>
-      partition(models, getIgnoreReasons, extractModelId),
-    ),
+    Effect.map(({ models }) => ({
+      ...partition(models, getIgnoreReasons, extractModelId),
+      metadata: extractMetadataMap(models, extractModelId, (model) => ({
+        input_modalities: getInputModalities(extractModelId(model)),
+      })),
+    })),
     Effect.timeout(REQUEST_TIMEOUT),
     Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
     Effect.runPromise,
   );
 }
+
+const getInputModalities = (modelId: string): InputModality[] => {
+  const normalizedId = modelId.toLowerCase();
+  const supportsMultimodal = /gemini/.test(normalizedId);
+  return supportsMultimodal ? ["text", "image"] : ["text"];
+};
diff --git a/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts b/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts
index 8b0109f2f2..c4322145ae 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts
@@ -6,6 +6,7 @@ import {
   type IgnoredModel,
   type ListModelsResult,
   type ModelIgnoreReason,
+  type ModelMetadata,
   REQUEST_TIMEOUT,
 } from "./list-common";
@@ -74,6 +75,7 @@ const processLMStudioModels = (
 ): ListModelsResult => {
   const models: string[] = [];
   const ignored: IgnoredModel[] = [];
+  const metadata: Record<string, ModelMetadata> = {};
 
   for (const model of downloadedModels) {
     const reasons: ModelIgnoreReason[] = [];
@@ -91,6 +93,8 @@ const processLMStudioModels = (
     if (reasons.length === 0) {
       models.push(model.path);
+      // TODO: Seems like LM Studio does not have a way to know input modality.
+      metadata[model.path] = { input_modalities: ["text"] };
     } else {
       ignored.push({ id: model.path, reasons });
     }
@@ -107,5 +111,5 @@ const processLMStudioModels = (
     return aLoaded ? -1 : 1;
   });
 
-  return { models, ignored };
+  return { models, ignored, metadata };
 };
diff --git a/apps/desktop/src/components/settings/ai/shared/list-ollama.ts b/apps/desktop/src/components/settings/ai/shared/list-ollama.ts
index d24bf31245..3f4de70c15 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-ollama.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-ollama.ts
@@ -6,6 +6,7 @@ import {
   type IgnoredModel,
   type ListModelsResult,
   type ModelIgnoreReason,
+  type ModelMetadata,
   REQUEST_TIMEOUT,
 } from "./list-common";
@@ -90,6 +91,7 @@ const summarizeOllamaDetails = (
 ): ListModelsResult => {
   const supported: Array<{ name: string; isRunning: boolean }> = [];
   const ignored: IgnoredModel[] = [];
+  const metadata: Record<string, ModelMetadata> = {};
 
   for (const detail of details) {
     const hasCompletion = detail.capabilities.includes("completion");
@@ -97,6 +99,8 @@ const summarizeOllamaDetails = (
 
     if (hasCompletion && hasTools) {
       supported.push({ name: detail.name, isRunning: detail.isRunning });
+      // TODO: Seems like Ollama does not have a way to know input modality.
+      metadata[detail.name] = { input_modalities: ["text"] };
     } else {
       const reasons: ModelIgnoreReason[] = [];
       if (!hasCompletion) {
@@ -119,5 +123,6 @@ const summarizeOllamaDetails = (
   return {
     models: supported.map((detail) => detail.name),
     ignored,
+    metadata,
   };
 };
diff --git a/apps/desktop/src/components/settings/ai/shared/list-openai.ts b/apps/desktop/src/components/settings/ai/shared/list-openai.ts
index 388556c2ba..b880adedf2 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-openai.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-openai.ts
@@ -2,6 +2,7 @@ import { Effect, pipe, Schema } from "effect";
 
 import {
   DEFAULT_RESULT,
+  extractMetadataMap,
   fetchJson,
   type ListModelsResult,
   type ModelIgnoreReason,
@@ -29,8 +30,8 @@ export async function listOpenAIModels(
   return pipe(
     fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
     Effect.andThen((json) => Schema.decodeUnknown(OpenAIModelSchema)(json)),
-    Effect.map(({ data }) =>
-      partition(
+    Effect.map(({ data }) => ({
+      ...partition(
         data,
         (model) => {
           const reasons: ModelIgnoreReason[] = [];
@@ -41,7 +42,12 @@
         },
         (model) => model.id,
       ),
-    ),
+      metadata: extractMetadataMap(
+        data,
+        (model) => model.id,
+        (_model) => ({ input_modalities: ["text", "image"] }),
+      ),
+    })),
     Effect.timeout(REQUEST_TIMEOUT),
     Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
     Effect.runPromise,
@@ -59,8 +65,8 @@ export async function listGenericModels(
   return pipe(
     fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
     Effect.andThen((json) => Schema.decodeUnknown(OpenAIModelSchema)(json)),
-    Effect.map(({ data }) =>
-      partition(
+    Effect.map(({ data }) => ({
+      ...partition(
         data,
         (model) => {
           const reasons: ModelIgnoreReason[] = [];
@@ -71,7 +77,12 @@
         },
         (model) => model.id,
       ),
-    ),
+      metadata: extractMetadataMap(
+        data,
+        (model) => model.id,
+        () => ({ input_modalities: ["text"] }),
+      ),
+    })),
     Effect.timeout(REQUEST_TIMEOUT),
     Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
     Effect.runPromise,
diff --git a/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts b/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts
index 0fde257922..d68f2df05e 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts
@@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";
 
 import {
   DEFAULT_RESULT,
+  extractMetadataMap,
   fetchJson,
+  type InputModality,
   type ListModelsResult,
   type ModelIgnoreReason,
   partition,
@@ -69,11 +71,29 @@ export async function listOpenRouterModels(
   return pipe(
     fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
     Effect.andThen((json) => Schema.decodeUnknown(OpenRouterModelSchema)(json)),
-    Effect.map(({ data }) =>
-      partition(data, getIgnoreReasons, (model) => model.id),
-    ),
+    Effect.map(({ data }) => ({
+      ...partition(data, getIgnoreReasons, (model) => model.id),
+      metadata: extractMetadataMap(
+        data,
+        (model) => model.id,
+        (model) => ({ input_modalities: getInputModalities(model) }),
+      ),
+    })),
     Effect.timeout(REQUEST_TIMEOUT),
     Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
     Effect.runPromise,
   );
 }
+
+const getInputModalities = (model: OpenRouterModel): InputModality[] => {
+  const modalities = model.architecture?.input_modalities ?? [];
+
+  return [
+    ...((modalities.includes("text")
+      ? ["text"]
+      : []) satisfies InputModality[]),
+    ...((modalities.includes("image")
+      ? ["image"]
+      : []) satisfies InputModality[]),
+  ];
+};
diff --git a/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx b/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx
index 2d957f5979..49788fc141 100644
--- a/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx
+++ b/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx
@@ -1,4 +1,3 @@
-import { useQuery } from "@tanstack/react-query";
 import { ChevronDown, CirclePlus, Eye, EyeOff } from "lucide-react";
 import { useCallback, useMemo, useState } from "react";
 
@@ -23,6 +22,7 @@ import {
 } from "@hypr/ui/components/ui/tooltip";
 
 import { cn } from "@hypr/utils";
+import { useModelMetadata } from "../../../../hooks/useModelMetadata";
 import type { ListModelsResult, ModelIgnoreReason } from "./list-common";
 
 const filterFunction = (value: string, search: string) => {
@@ -62,7 +62,7 @@ export function ModelCombobox({
   providerId: string;
   value: string;
   onChange: (value: string) => void;
-  listModels: () => Promise<ListModelsResult> | ListModelsResult;
+  listModels?: () => Promise<ListModelsResult> | ListModelsResult;
   disabled?: boolean;
   placeholder?: string;
 }) {
@@ -70,12 +70,11 @@
   const [query, setQuery] = useState("");
   const [showIgnored, setShowIgnored] = useState(false);
 
-  const { data: fetchedResult, isLoading } = useQuery({
-    queryKey: ["models", providerId, listModels],
-    queryFn: listModels,
-    retry: 3,
-    retryDelay: 300,
-  });
+  const { data: fetchedResult, isLoading: isLoadingModels } = useModelMetadata(
+    providerId,
+    listModels,
+    { enabled: !disabled },
+  );
 
   const options: string[] = useMemo(
     () => fetchedResult?.models ?? [],
@@ -118,7 +117,7 @@
          type="button"
          variant="outline"
          role="combobox"
-          disabled={disabled || isLoading}
+          disabled={disabled || isLoadingModels}
          aria-expanded={open}
          className={cn(["w-full justify-between font-normal bg-white"])}
        >
@@ -126,7 +125,7 @@
            {value}
          ) : (
-            {isLoading ? "Loading models..." : placeholder}
+            {isLoadingModels ? "Loading models..." : placeholder}
          )}
diff --git a/apps/desktop/src/hooks/useCurrentModelModalitySupport.ts b/apps/desktop/src/hooks/useCurrentModelModalitySupport.ts
new file mode 100644
index 0000000000..559bd8189c
--- /dev/null
+++ b/apps/desktop/src/hooks/useCurrentModelModalitySupport.ts
@@ -0,0 +1,101 @@
+import { useMemo } from "react";
+
+import { useAuth } from "../auth";
+import {
+  type ProviderId,
+  PROVIDERS,
+} from "../components/settings/ai/llm/shared";
+import { listAnthropicModels } from "../components/settings/ai/shared/list-anthropic";
+import type { InputModality } from "../components/settings/ai/shared/list-common";
+import { listGoogleModels } from "../components/settings/ai/shared/list-google";
+import { listLMStudioModels } from "../components/settings/ai/shared/list-lmstudio";
+import { listOllamaModels } from "../components/settings/ai/shared/list-ollama";
+import {
+  listGenericModels,
+  listOpenAIModels,
+} from "../components/settings/ai/shared/list-openai";
+import { listOpenRouterModels } from "../components/settings/ai/shared/list-openrouter";
+import * as main from "../store/tinybase/main";
+import { useModelMetadata } from "./useModelMetadata";
+
+export function useCurrentModelModalitySupport(): InputModality[] | null {
+  const auth = useAuth();
+  const { current_llm_provider, current_llm_model } = main.UI.useValues(
+    main.STORE_ID,
+  );
+  const providerConfig = main.UI.useRow(
+    "ai_providers",
+    current_llm_provider ?? "",
+    main.STORE_ID,
+  ) as main.AIProviderStorage | undefined;
+
+  const providerId = current_llm_provider as ProviderId | null;
+  const providerDef = PROVIDERS.find((provider) => provider.id === providerId);
+
+  const listModels = useMemo(() => {
+    if (!providerId || !current_llm_model) {
+      return undefined;
+    }
+
+    if (providerId === "hyprnote") {
+      if (!auth?.session) {
+        return undefined;
+      }
+      return async () => ({
+        models: ["Auto"],
+        ignored: [],
+        metadata: {
+          Auto: {
+            input_modalities: ["text", "image"] as InputModality[],
+          },
+        },
+      });
+    }
+
+    const baseUrl =
+      providerConfig?.base_url?.trim() || providerDef?.baseUrl?.trim() || "";
+    const apiKey = providerConfig?.api_key?.trim() || "";
+
+    if (!baseUrl || (providerDef?.apiKey && !apiKey)) {
+      return undefined;
+    }
+
+    return getFetcher(providerId, baseUrl, apiKey);
+  }, [
+    providerId,
+    current_llm_model,
+    auth?.session,
+    providerConfig?.base_url,
+    providerConfig?.api_key,
+    providerDef,
+  ]);
+
+  const { data } = useModelMetadata(providerId, listModels);
+
+  if (!current_llm_model || !data) {
+    return null;
+  }
+
+  return data.metadata?.[current_llm_model]?.input_modalities ?? null;
+}
+
+function getFetcher(providerId: ProviderId, baseUrl: string, apiKey: string) {
+  switch (providerId) {
+    case "openai":
+      return () => listOpenAIModels(baseUrl, apiKey);
+    case "anthropic":
+      return () => listAnthropicModels(baseUrl, apiKey);
+    case "openrouter":
+      return () => listOpenRouterModels(baseUrl, apiKey);
+    case "google_generative_ai":
+      return () => listGoogleModels(baseUrl, apiKey);
+    case "ollama":
+      return () => listOllamaModels(baseUrl, apiKey);
+    case "lmstudio":
+      return () => listLMStudioModels(baseUrl, apiKey);
+    case "custom":
+      return () => listGenericModels(baseUrl, apiKey);
+    default:
+      return () => listGenericModels(baseUrl, apiKey);
+  }
+}
diff --git a/apps/desktop/src/hooks/useModelMetadata.ts b/apps/desktop/src/hooks/useModelMetadata.ts
new file mode 100644
index 0000000000..7333280b85
--- /dev/null
+++ b/apps/desktop/src/hooks/useModelMetadata.ts
@@ -0,0 +1,32 @@
+import { useQuery } from "@tanstack/react-query";
+
+import {
+  DEFAULT_RESULT,
+  type ListModelsResult,
+} from "../components/settings/ai/shared/list-common";
+
+export function useModelMetadata(
+  providerId: string | null,
+  listModels: (() => Promise<ListModelsResult> | ListModelsResult) | undefined,
+  options?: {
+    enabled?: boolean;
+  },
+) {
+  const enabled = options?.enabled ?? Boolean(providerId && listModels);
+
+  const { data, isLoading } = useQuery({
+    queryKey: ["models", providerId],
+    queryFn: async () => {
+      if (!listModels) {
+        return DEFAULT_RESULT;
+      }
+      return await listModels();
+    },
+    enabled,
+    retry: 3,
+    retryDelay: 300,
+    staleTime: 1000 * 60,
+  });
+
+  return { data, isLoading };
+}