diff --git a/apps/desktop/src/components/settings/ai/llm/select.tsx b/apps/desktop/src/components/settings/ai/llm/select.tsx
index 404dbe514c..42f64f4ccf 100644
--- a/apps/desktop/src/components/settings/ai/llm/select.tsx
+++ b/apps/desktop/src/components/settings/ai/llm/select.tsx
@@ -15,7 +15,10 @@ import { useBillingAccess } from "../../../../billing";
import { useConfigValues } from "../../../../config/use-config";
import * as main from "../../../../store/tinybase/main";
import { listAnthropicModels } from "../shared/list-anthropic";
-import type { ListModelsResult } from "../shared/list-common";
+import {
+ type InputModality,
+ type ListModelsResult,
+} from "../shared/list-common";
import { listGoogleModels } from "../shared/list-google";
import { listLMStudioModels } from "../shared/list-lmstudio";
import { listOllamaModels } from "../shared/list-ollama";
@@ -143,12 +146,7 @@ export function SelectProviderAndModel() {
const providerRequiresPro = providerDef?.requiresPro ?? false;
const locked = providerRequiresPro && !billing.isPro;
- const listModels = () => {
- if (!maybeListModels || locked) {
- return { models: [], ignored: [] };
- }
- return maybeListModels();
- };
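+ // Locked providers pass no fetcher, so the model list simply stays empty.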
+ const listModels = !locked ? maybeListModels : undefined;
return (
@@ -181,7 +179,7 @@ export function SelectProviderAndModel() {
function useConfiguredMapping(): Record<
string,
- null | (() => Promise<ListModelsResult>)
+ undefined | (() => Promise<ListModelsResult>)
> {
const auth = useAuth();
const billing = useBillingAccess();
@@ -194,25 +192,35 @@ function useConfiguredMapping(): Record<
return Object.fromEntries(
PROVIDERS.map((provider) => {
if (provider.requiresPro && !billing.isPro) {
- return [provider.id, null];
+ return [provider.id, undefined];
}
if (provider.id === "hyprnote") {
if (!auth?.session) {
- return [provider.id, null];
+ return [provider.id, undefined];
}
- return [provider.id, async () => ({ models: ["Auto"], ignored: [] })];
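+ // Hyprnote exposes a single "Auto" model, treated here as accepting text and image input.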
+ const result: ListModelsResult = {
+ models: ["Auto"],
+ ignored: [],
+ metadata: {
+ Auto: {
+ input_modalities: ["text", "image"] as InputModality[],
+ },
+ },
+ };
+
+ return [provider.id, async () => result];
}
const config = configuredProviders[provider.id];
if (!config || !config.base_url) {
- return [provider.id, null];
+ return [provider.id, undefined];
}
if (provider.apiKey && !config.api_key) {
- return [provider.id, null];
+ return [provider.id, undefined];
}
const { base_url, api_key } = config;
@@ -249,7 +257,7 @@ function useConfiguredMapping(): Record<
return [provider.id, listModelsFunc];
}),
- ) as Record<string, null | (() => Promise<ListModelsResult>)>;
+ ) as Record<string, undefined | (() => Promise<ListModelsResult>)>;
}, [configuredProviders, auth, billing]);
return mapping;
diff --git a/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts b/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts
index 7d8dece637..2c6e2e11f6 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-anthropic.ts
@@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";
import {
DEFAULT_RESULT,
+ extractMetadataMap,
fetchJson,
+ type InputModality,
type ListModelsResult,
type ModelIgnoreReason,
partition,
@@ -39,8 +41,8 @@ export async function listAnthropicModels(
"anthropic-dangerous-direct-browser-access": "true",
}),
Effect.andThen((json) => Schema.decodeUnknown(AnthropicModelSchema)(json)),
- Effect.map(({ data }) =>
- partition(
+ Effect.map(({ data }) => ({
+ ...partition(
data,
(model) => {
const reasons: ModelIgnoreReason[] = [];
@@ -51,9 +53,18 @@ export async function listAnthropicModels(
},
(model) => model.id,
),
- ),
+ metadata: extractMetadataMap(
+ data,
+ (model) => model.id,
+ (model) => ({ input_modalities: getInputModalities(model.id) }),
+ ),
+ })),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
);
}
+
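+// Modality is not derived per model here; assume text + image for all listed Anthropic models.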
+const getInputModalities = (_modelId: string): InputModality[] => {
+ return ["text", "image"];
+};
diff --git a/apps/desktop/src/components/settings/ai/shared/list-common.ts b/apps/desktop/src/components/settings/ai/shared/list-common.ts
index 6ebd51b5cf..722792c916 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-common.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-common.ts
@@ -11,12 +11,23 @@ export type ModelIgnoreReason =
export type IgnoredModel = { id: string; reasons: ModelIgnoreReason[] };
+export type InputModality = "image" | "text";
+
+export type ModelMetadata = {
+ input_modalities?: InputModality[];
+};
+
export type ListModelsResult = {
models: string[];
ignored: IgnoredModel[];
+ metadata: Record<string, ModelMetadata>;
};
-export const DEFAULT_RESULT: ListModelsResult = { models: [], ignored: [] };
+export const DEFAULT_RESULT: ListModelsResult = {
+ models: [],
+ ignored: [],
+ metadata: {},
+};
export const REQUEST_TIMEOUT = "5 seconds";
export const commonIgnoreKeywords = [
@@ -49,20 +60,52 @@ export const shouldIgnoreCommonKeywords = (id: string): boolean => {
return commonIgnoreKeywords.some((keyword) => lowerId.includes(keyword));
};
+const hasMetadata = (metadata: ModelMetadata | undefined): boolean => {
+ if (!metadata) {
+ return false;
+ }
+ if (metadata.input_modalities && metadata.input_modalities.length > 0) {
+ return true;
+ }
+ return false;
+};
+
export const partition = <T>(
items: readonly T[],
shouldIgnore: (item: T) => ModelIgnoreReason[] | null,
extract: (item: T) => string,
-): ListModelsResult => {
- const result = { models: [] as string[], ignored: [] as IgnoredModel[] };
+): { models: string[]; ignored: IgnoredModel[] } => {
+ const models: string[] = [];
+ const ignored: IgnoredModel[] = [];
+
for (const item of items) {
const reasons = shouldIgnore(item);
+ const id = extract(item);
if (!reasons || reasons.length === 0) {
- result.models.push(extract(item));
+ models.push(id);
} else {
- result.ignored.push({ id: extract(item), reasons });
+ ignored.push({ id, reasons });
}
}
- return result;
+
+ return { models, ignored };
+};
+
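+// Builds a model-id → metadata map, skipping entries that carry no meaningful metadata.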
+export const extractMetadataMap = <T>(
+ items: readonly T[],
+ extract: (item: T) => string,
+ extractMetadata: (item: T) => ModelMetadata | undefined,
+): Record<string, ModelMetadata> => {
+ const metadata: Record<string, ModelMetadata> = {};
+
+ for (const item of items) {
+ const id = extract(item);
+ const meta = extractMetadata(item);
+ if (hasMetadata(meta)) {
+ metadata[id] = meta!;
+ }
+ }
+
+ return metadata;
};
diff --git a/apps/desktop/src/components/settings/ai/shared/list-google.ts b/apps/desktop/src/components/settings/ai/shared/list-google.ts
index c2f133246b..932cb1d09a 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-google.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-google.ts
@@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";
import {
DEFAULT_RESULT,
+ extractMetadataMap,
fetchJson,
+ type InputModality,
type ListModelsResult,
type ModelIgnoreReason,
partition,
@@ -53,11 +55,20 @@ export async function listGoogleModels(
return pipe(
fetchJson(`${baseUrl}/models`, { "x-goog-api-key": apiKey }),
Effect.andThen((json) => Schema.decodeUnknown(GoogleModelSchema)(json)),
- Effect.map(({ models }) =>
- partition(models, getIgnoreReasons, extractModelId),
- ),
+ Effect.map(({ models }) => ({
+ ...partition(models, getIgnoreReasons, extractModelId),
+ metadata: extractMetadataMap(models, extractModelId, (model) => ({
+ input_modalities: getInputModalities(extractModelId(model)),
+ })),
+ })),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
);
}
+
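+// Gemini-family models are treated as multimodal (text + image); anything else is assumed text-only.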
+const getInputModalities = (modelId: string): InputModality[] => {
+ const normalizedId = modelId.toLowerCase();
+ const supportsMultimodal = /gemini/.test(normalizedId);
+ return supportsMultimodal ? ["text", "image"] : ["text"];
+};
diff --git a/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts b/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts
index 8b0109f2f2..c4322145ae 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-lmstudio.ts
@@ -6,6 +6,7 @@ import {
type IgnoredModel,
type ListModelsResult,
type ModelIgnoreReason,
+ type ModelMetadata,
REQUEST_TIMEOUT,
} from "./list-common";
@@ -74,6 +75,7 @@ const processLMStudioModels = (
): ListModelsResult => {
const models: string[] = [];
const ignored: IgnoredModel[] = [];
+ const metadata: Record<string, ModelMetadata> = {};
for (const model of downloadedModels) {
const reasons: ModelIgnoreReason[] = [];
@@ -91,6 +93,8 @@ const processLMStudioModels = (
if (reasons.length === 0) {
models.push(model.path);
+ // TODO: LMStudio does not seem to provide a way to determine input modality.
+ metadata[model.path] = { input_modalities: ["text"] };
} else {
ignored.push({ id: model.path, reasons });
}
@@ -107,5 +111,5 @@ const processLMStudioModels = (
return aLoaded ? -1 : 1;
});
- return { models, ignored };
+ return { models, ignored, metadata };
};
diff --git a/apps/desktop/src/components/settings/ai/shared/list-ollama.ts b/apps/desktop/src/components/settings/ai/shared/list-ollama.ts
index d24bf31245..3f4de70c15 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-ollama.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-ollama.ts
@@ -6,6 +6,7 @@ import {
type IgnoredModel,
type ListModelsResult,
type ModelIgnoreReason,
+ type ModelMetadata,
REQUEST_TIMEOUT,
} from "./list-common";
@@ -90,6 +91,7 @@ const summarizeOllamaDetails = (
): ListModelsResult => {
const supported: Array<{ name: string; isRunning: boolean }> = [];
const ignored: IgnoredModel[] = [];
+ const metadata: Record<string, ModelMetadata> = {};
for (const detail of details) {
const hasCompletion = detail.capabilities.includes("completion");
@@ -97,6 +99,8 @@ const summarizeOllamaDetails = (
if (hasCompletion && hasTools) {
supported.push({ name: detail.name, isRunning: detail.isRunning });
+ // TODO: Ollama does not seem to provide a way to determine input modality.
+ metadata[detail.name] = { input_modalities: ["text"] };
} else {
const reasons: ModelIgnoreReason[] = [];
if (!hasCompletion) {
@@ -119,5 +123,6 @@ const summarizeOllamaDetails = (
return {
models: supported.map((detail) => detail.name),
ignored,
+ metadata,
};
};
diff --git a/apps/desktop/src/components/settings/ai/shared/list-openai.ts b/apps/desktop/src/components/settings/ai/shared/list-openai.ts
index 388556c2ba..b880adedf2 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-openai.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-openai.ts
@@ -2,6 +2,7 @@ import { Effect, pipe, Schema } from "effect";
import {
DEFAULT_RESULT,
+ extractMetadataMap,
fetchJson,
type ListModelsResult,
type ModelIgnoreReason,
@@ -29,8 +30,8 @@ export async function listOpenAIModels(
return pipe(
fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
Effect.andThen((json) => Schema.decodeUnknown(OpenAIModelSchema)(json)),
- Effect.map(({ data }) =>
- partition(
+ Effect.map(({ data }) => ({
+ ...partition(
data,
(model) => {
const reasons: ModelIgnoreReason[] = [];
@@ -41,7 +42,12 @@ export async function listOpenAIModels(
},
(model) => model.id,
),
- ),
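+ // Per-model modality is not derived here; OpenAI models are assumed to accept text and image.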
+ metadata: extractMetadataMap(
+ data,
+ (model) => model.id,
+ (_model) => ({ input_modalities: ["text", "image"] }),
+ ),
+ })),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
@@ -59,8 +65,8 @@ export async function listGenericModels(
return pipe(
fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
Effect.andThen((json) => Schema.decodeUnknown(OpenAIModelSchema)(json)),
- Effect.map(({ data }) =>
- partition(
+ Effect.map(({ data }) => ({
+ ...partition(
data,
(model) => {
const reasons: ModelIgnoreReason[] = [];
@@ -71,7 +77,12 @@ export async function listGenericModels(
},
(model) => model.id,
),
- ),
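+ // Generic OpenAI-compatible endpoints are conservatively assumed to be text-only.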
+ metadata: extractMetadataMap(
+ data,
+ (model) => model.id,
+ () => ({ input_modalities: ["text"] }),
+ ),
+ })),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
diff --git a/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts b/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts
index 0fde257922..d68f2df05e 100644
--- a/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts
+++ b/apps/desktop/src/components/settings/ai/shared/list-openrouter.ts
@@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";
import {
DEFAULT_RESULT,
+ extractMetadataMap,
fetchJson,
+ type InputModality,
type ListModelsResult,
type ModelIgnoreReason,
partition,
@@ -69,11 +71,29 @@ export async function listOpenRouterModels(
return pipe(
fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
Effect.andThen((json) => Schema.decodeUnknown(OpenRouterModelSchema)(json)),
- Effect.map(({ data }) =>
- partition(data, getIgnoreReasons, (model) => model.id),
- ),
+ Effect.map(({ data }) => ({
+ ...partition(data, getIgnoreReasons, (model) => model.id),
+ metadata: extractMetadataMap(
+ data,
+ (model) => model.id,
+ (model) => ({ input_modalities: getInputModalities(model) }),
+ ),
+ })),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
);
}
+
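+// OpenRouter reports input modalities per model; keep only the modalities this app recognizes.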
+const getInputModalities = (model: OpenRouterModel): InputModality[] => {
+ const modalities = model.architecture?.input_modalities ?? [];
+
+ const supported: InputModality[] = [];
+ if (modalities.includes("text")) {
+ supported.push("text");
+ }
+ if (modalities.includes("image")) {
+ supported.push("image");
+ }
+ return supported;
+};
diff --git a/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx b/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx
index 2d957f5979..49788fc141 100644
--- a/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx
+++ b/apps/desktop/src/components/settings/ai/shared/model-combobox.tsx
@@ -1,4 +1,3 @@
-import { useQuery } from "@tanstack/react-query";
import { ChevronDown, CirclePlus, Eye, EyeOff } from "lucide-react";
import { useCallback, useMemo, useState } from "react";
@@ -23,6 +22,7 @@ import {
} from "@hypr/ui/components/ui/tooltip";
import { cn } from "@hypr/utils";
+import { useModelMetadata } from "../../../../hooks/useModelMetadata";
import type { ListModelsResult, ModelIgnoreReason } from "./list-common";
const filterFunction = (value: string, search: string) => {
@@ -62,7 +62,7 @@ export function ModelCombobox({
providerId: string;
value: string;
onChange: (value: string) => void;
- listModels: () => Promise<ListModelsResult> | ListModelsResult;
+ listModels?: () => Promise<ListModelsResult> | ListModelsResult;
disabled?: boolean;
placeholder?: string;
}) {
@@ -70,12 +70,11 @@ export function ModelCombobox({
const [query, setQuery] = useState("");
const [showIgnored, setShowIgnored] = useState(false);
- const { data: fetchedResult, isLoading } = useQuery({
- queryKey: ["models", providerId, listModels],
- queryFn: listModels,
- retry: 3,
- retryDelay: 300,
- });
+ const { data: fetchedResult, isLoading: isLoadingModels } = useModelMetadata(
+ providerId,
+ listModels,
+ { enabled: !disabled },
+ );
const options: string[] = useMemo(
() => fetchedResult?.models ?? [],
@@ -118,7 +117,7 @@ export function ModelCombobox({
type="button"
variant="outline"
role="combobox"
- disabled={disabled || isLoading}
+ disabled={disabled || isLoadingModels}
aria-expanded={open}
className={cn(["w-full justify-between font-normal bg-white"])}
>
@@ -126,7 +125,7 @@ export function ModelCombobox({
{value}
) : (
- {isLoading ? "Loading models..." : placeholder}
+ {isLoadingModels ? "Loading models..." : placeholder}
)}
diff --git a/apps/desktop/src/hooks/useCurrentModelModalitySupport.ts b/apps/desktop/src/hooks/useCurrentModelModalitySupport.ts
new file mode 100644
index 0000000000..559bd8189c
--- /dev/null
+++ b/apps/desktop/src/hooks/useCurrentModelModalitySupport.ts
@@ -0,0 +1,101 @@
+import { useMemo } from "react";
+
+import { useAuth } from "../auth";
+import {
+ type ProviderId,
+ PROVIDERS,
+} from "../components/settings/ai/llm/shared";
+import { listAnthropicModels } from "../components/settings/ai/shared/list-anthropic";
+import type { InputModality } from "../components/settings/ai/shared/list-common";
+import { listGoogleModels } from "../components/settings/ai/shared/list-google";
+import { listLMStudioModels } from "../components/settings/ai/shared/list-lmstudio";
+import { listOllamaModels } from "../components/settings/ai/shared/list-ollama";
+import {
+ listGenericModels,
+ listOpenAIModels,
+} from "../components/settings/ai/shared/list-openai";
+import { listOpenRouterModels } from "../components/settings/ai/shared/list-openrouter";
+import * as main from "../store/tinybase/main";
+import { useModelMetadata } from "./useModelMetadata";
+
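+// Resolves the input modalities of the currently selected LLM model, or null when they cannot be determined.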
+export function useCurrentModelModalitySupport(): InputModality[] | null {
+ const auth = useAuth();
+ const { current_llm_provider, current_llm_model } = main.UI.useValues(
+ main.STORE_ID,
+ );
+ const providerConfig = main.UI.useRow(
+ "ai_providers",
+ current_llm_provider ?? "",
+ main.STORE_ID,
+ ) as main.AIProviderStorage | undefined;
+
+ const providerId = current_llm_provider as ProviderId | null;
+ const providerDef = PROVIDERS.find((provider) => provider.id === providerId);
+
+ const listModels = useMemo(() => {
+ if (!providerId || !current_llm_model) {
+ return undefined;
+ }
+
+ if (providerId === "hyprnote") {
+ if (!auth?.session) {
+ return undefined;
+ }
+ return async () => ({
+ models: ["Auto"],
+ ignored: [],
+ metadata: {
+ Auto: {
+ input_modalities: ["text", "image"] as InputModality[],
+ },
+ },
+ });
+ }
+
+ const baseUrl =
+ providerConfig?.base_url?.trim() || providerDef?.baseUrl?.trim() || "";
+ const apiKey = providerConfig?.api_key?.trim() || "";
+
+ if (!baseUrl || (providerDef?.apiKey && !apiKey)) {
+ return undefined;
+ }
+
+ return getFetcher(providerId, baseUrl, apiKey);
+ }, [
+ providerId,
+ current_llm_model,
+ auth?.session,
+ providerConfig?.base_url,
+ providerConfig?.api_key,
+ providerDef,
+ ]);
+
+ const { data } = useModelMetadata(providerId, listModels);
+
+ if (!current_llm_model || !data) {
+ return null;
+ }
+
+ return data.metadata?.[current_llm_model]?.input_modalities ?? null;
+}
+
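+// Picks the model-listing function for a provider; unknown providers fall back to the generic OpenAI-compatible listing.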
+function getFetcher(providerId: ProviderId, baseUrl: string, apiKey: string) {
+ switch (providerId) {
+ case "openai":
+ return () => listOpenAIModels(baseUrl, apiKey);
+ case "anthropic":
+ return () => listAnthropicModels(baseUrl, apiKey);
+ case "openrouter":
+ return () => listOpenRouterModels(baseUrl, apiKey);
+ case "google_generative_ai":
+ return () => listGoogleModels(baseUrl, apiKey);
+ case "ollama":
+ return () => listOllamaModels(baseUrl, apiKey);
+ case "lmstudio":
+ return () => listLMStudioModels(baseUrl, apiKey);
+ case "custom":
+ return () => listGenericModels(baseUrl, apiKey);
+ default:
+ return () => listGenericModels(baseUrl, apiKey);
+ }
+}
diff --git a/apps/desktop/src/hooks/useModelMetadata.ts b/apps/desktop/src/hooks/useModelMetadata.ts
new file mode 100644
index 0000000000..7333280b85
--- /dev/null
+++ b/apps/desktop/src/hooks/useModelMetadata.ts
@@ -0,0 +1,32 @@
+import { useQuery } from "@tanstack/react-query";
+
+import {
+ DEFAULT_RESULT,
+ type ListModelsResult,
+} from "../components/settings/ai/shared/list-common";
+
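+// Shared react-query wrapper around model listing; results are cached per provider (staleTime: 1 minute).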
+export function useModelMetadata(
+ providerId: string | null,
+ listModels: (() => Promise<ListModelsResult> | ListModelsResult) | undefined,
+ options?: {
+ enabled?: boolean;
+ },
+) {
+ const enabled = options?.enabled ?? Boolean(providerId && listModels);
+
+ const { data, isLoading } = useQuery({
+ queryKey: ["models", providerId],
+ queryFn: async () => {
+ if (!listModels) {
+ return DEFAULT_RESULT;
+ }
+ return await listModels();
+ },
+ enabled,
+ retry: 3,
+ retryDelay: 300,
+ staleTime: 1000 * 60,
+ });
+
+ return { data, isLoading };
+}