Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions apps/desktop/src/components/settings/ai/llm/select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ import { useBillingAccess } from "../../../../billing";
import { useConfigValues } from "../../../../config/use-config";
import * as main from "../../../../store/tinybase/main";
import { listAnthropicModels } from "../shared/list-anthropic";
import type { ListModelsResult } from "../shared/list-common";
import {
type InputModality,
type ListModelsResult,
} from "../shared/list-common";
import { listGoogleModels } from "../shared/list-google";
import { listLMStudioModels } from "../shared/list-lmstudio";
import { listOllamaModels } from "../shared/list-ollama";
Expand Down Expand Up @@ -143,12 +146,7 @@ export function SelectProviderAndModel() {
const providerRequiresPro = providerDef?.requiresPro ?? false;
const locked = providerRequiresPro && !billing.isPro;

const listModels = () => {
if (!maybeListModels || locked) {
return { models: [], ignored: [] };
}
return maybeListModels();
};
const listModels = !locked ? maybeListModels : undefined;

return (
<div className="flex-[3] min-w-0">
Expand Down Expand Up @@ -181,7 +179,7 @@ export function SelectProviderAndModel() {

function useConfiguredMapping(): Record<
string,
null | (() => Promise<ListModelsResult>)
undefined | (() => Promise<ListModelsResult>)
> {
const auth = useAuth();
const billing = useBillingAccess();
Expand All @@ -194,25 +192,35 @@ function useConfiguredMapping(): Record<
return Object.fromEntries(
PROVIDERS.map((provider) => {
if (provider.requiresPro && !billing.isPro) {
return [provider.id, null];
return [provider.id, undefined];
}

if (provider.id === "hyprnote") {
if (!auth?.session) {
return [provider.id, null];
return [provider.id, undefined];
}

return [provider.id, async () => ({ models: ["Auto"], ignored: [] })];
const result: ListModelsResult = {
models: ["Auto"],
ignored: [],
metadata: {
Auto: {
input_modalities: ["text", "image"] as InputModality[],
},
},
};

return [provider.id, async () => result];
}

const config = configuredProviders[provider.id];

if (!config || !config.base_url) {
return [provider.id, null];
return [provider.id, undefined];
}

if (provider.apiKey && !config.api_key) {
return [provider.id, null];
return [provider.id, undefined];
}

const { base_url, api_key } = config;
Expand Down Expand Up @@ -249,7 +257,7 @@ function useConfiguredMapping(): Record<

return [provider.id, listModelsFunc];
}),
) as Record<string, null | (() => Promise<ListModelsResult>)>;
) as Record<string, undefined | (() => Promise<ListModelsResult>)>;
}, [configuredProviders, auth, billing]);

return mapping;
Expand Down
17 changes: 14 additions & 3 deletions apps/desktop/src/components/settings/ai/shared/list-anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";

import {
DEFAULT_RESULT,
extractMetadataMap,
fetchJson,
type InputModality,
type ListModelsResult,
type ModelIgnoreReason,
partition,
Expand Down Expand Up @@ -39,8 +41,8 @@ export async function listAnthropicModels(
"anthropic-dangerous-direct-browser-access": "true",
}),
Effect.andThen((json) => Schema.decodeUnknown(AnthropicModelSchema)(json)),
Effect.map(({ data }) =>
partition(
Effect.map(({ data }) => ({
...partition(
data,
(model) => {
const reasons: ModelIgnoreReason[] = [];
Expand All @@ -51,9 +53,18 @@ export async function listAnthropicModels(
},
(model) => model.id,
),
),
metadata: extractMetadataMap(
data,
(model) => model.id,
(model) => ({ input_modalities: getInputModalities(model.id) }),
),
})),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
);
}

const getInputModalities = (_modelId: string): InputModality[] => {
return ["text", "image"];
};
55 changes: 49 additions & 6 deletions apps/desktop/src/components/settings/ai/shared/list-common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,23 @@ export type ModelIgnoreReason =

export type IgnoredModel = { id: string; reasons: ModelIgnoreReason[] };

export type InputModality = "image" | "text";

export type ModelMetadata = {
input_modalities?: InputModality[];
};

export type ListModelsResult = {
models: string[];
ignored: IgnoredModel[];
metadata: Record<string, ModelMetadata>;
};

export const DEFAULT_RESULT: ListModelsResult = { models: [], ignored: [] };
export const DEFAULT_RESULT: ListModelsResult = {
models: [],
ignored: [],
metadata: {},
};
export const REQUEST_TIMEOUT = "5 seconds";

export const commonIgnoreKeywords = [
Expand Down Expand Up @@ -49,20 +60,52 @@ export const shouldIgnoreCommonKeywords = (id: string): boolean => {
return commonIgnoreKeywords.some((keyword) => lowerId.includes(keyword));
};

const hasMetadata = (metadata: ModelMetadata | undefined): boolean => {
if (!metadata) {
return false;
}
if (metadata.input_modalities && metadata.input_modalities.length > 0) {
return true;
}
return false;
};

export const partition = <T,>(
items: readonly T[],
shouldIgnore: (item: T) => ModelIgnoreReason[] | null,
extract: (item: T) => string,
): ListModelsResult => {
const result = { models: [] as string[], ignored: [] as IgnoredModel[] };
): { models: string[]; ignored: IgnoredModel[] } => {
const models: string[] = [];
const ignored: IgnoredModel[] = [];

for (const item of items) {
const reasons = shouldIgnore(item);
const id = extract(item);

if (!reasons || reasons.length === 0) {
result.models.push(extract(item));
models.push(id);
} else {
result.ignored.push({ id: extract(item), reasons });
ignored.push({ id, reasons });
}
}
return result;

return { models, ignored };
};

export const extractMetadataMap = <T,>(
items: readonly T[],
extract: (item: T) => string,
extractMetadata: (item: T) => ModelMetadata | undefined,
): Record<string, ModelMetadata> => {
const metadata: Record<string, ModelMetadata> = {};

for (const item of items) {
const id = extract(item);
const meta = extractMetadata(item);
if (hasMetadata(meta)) {
metadata[id] = meta!;
}
}

return metadata;
};
17 changes: 14 additions & 3 deletions apps/desktop/src/components/settings/ai/shared/list-google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ import { Effect, pipe, Schema } from "effect";

import {
DEFAULT_RESULT,
extractMetadataMap,
fetchJson,
type InputModality,
type ListModelsResult,
type ModelIgnoreReason,
partition,
Expand Down Expand Up @@ -53,11 +55,20 @@ export async function listGoogleModels(
return pipe(
fetchJson(`${baseUrl}/models`, { "x-goog-api-key": apiKey }),
Effect.andThen((json) => Schema.decodeUnknown(GoogleModelSchema)(json)),
Effect.map(({ models }) =>
partition(models, getIgnoreReasons, extractModelId),
),
Effect.map(({ models }) => ({
...partition(models, getIgnoreReasons, extractModelId),
metadata: extractMetadataMap(models, extractModelId, (model) => ({
input_modalities: getInputModalities(extractModelId(model)),
})),
})),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
);
}

const getInputModalities = (modelId: string): InputModality[] => {
const normalizedId = modelId.toLowerCase();
const supportsMultimodal = /gemini/.test(normalizedId);
return supportsMultimodal ? ["text", "image"] : ["text"];
};
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
type IgnoredModel,
type ListModelsResult,
type ModelIgnoreReason,
type ModelMetadata,
REQUEST_TIMEOUT,
} from "./list-common";

Expand Down Expand Up @@ -74,6 +75,7 @@ const processLMStudioModels = (
): ListModelsResult => {
const models: string[] = [];
const ignored: IgnoredModel[] = [];
const metadata: Record<string, ModelMetadata> = {};

for (const model of downloadedModels) {
const reasons: ModelIgnoreReason[] = [];
Expand All @@ -91,6 +93,8 @@ const processLMStudioModels = (

if (reasons.length === 0) {
models.push(model.path);
// TODO: Seems like LMStudio do not have way to know input modality.
metadata[model.path] = { input_modalities: ["text"] };
} else {
ignored.push({ id: model.path, reasons });
}
Expand All @@ -107,5 +111,5 @@ const processLMStudioModels = (
return aLoaded ? -1 : 1;
});

return { models, ignored };
return { models, ignored, metadata };
};
5 changes: 5 additions & 0 deletions apps/desktop/src/components/settings/ai/shared/list-ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
type IgnoredModel,
type ListModelsResult,
type ModelIgnoreReason,
type ModelMetadata,
REQUEST_TIMEOUT,
} from "./list-common";

Expand Down Expand Up @@ -90,13 +91,16 @@ const summarizeOllamaDetails = (
): ListModelsResult => {
const supported: Array<{ name: string; isRunning: boolean }> = [];
const ignored: IgnoredModel[] = [];
const metadata: Record<string, ModelMetadata> = {};

for (const detail of details) {
const hasCompletion = detail.capabilities.includes("completion");
const hasTools = detail.capabilities.includes("tools");

if (hasCompletion && hasTools) {
supported.push({ name: detail.name, isRunning: detail.isRunning });
// TODO: Seems like Ollama do not have way to know input modality.
metadata[detail.name] = { input_modalities: ["text"] };
} else {
const reasons: ModelIgnoreReason[] = [];
if (!hasCompletion) {
Expand All @@ -119,5 +123,6 @@ const summarizeOllamaDetails = (
return {
models: supported.map((detail) => detail.name),
ignored,
metadata,
};
};
23 changes: 17 additions & 6 deletions apps/desktop/src/components/settings/ai/shared/list-openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Effect, pipe, Schema } from "effect";

import {
DEFAULT_RESULT,
extractMetadataMap,
fetchJson,
type ListModelsResult,
type ModelIgnoreReason,
Expand Down Expand Up @@ -29,8 +30,8 @@ export async function listOpenAIModels(
return pipe(
fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
Effect.andThen((json) => Schema.decodeUnknown(OpenAIModelSchema)(json)),
Effect.map(({ data }) =>
partition(
Effect.map(({ data }) => ({
...partition(
data,
(model) => {
const reasons: ModelIgnoreReason[] = [];
Expand All @@ -41,7 +42,12 @@ export async function listOpenAIModels(
},
(model) => model.id,
),
),
metadata: extractMetadataMap(
data,
(model) => model.id,
(_model) => ({ input_modalities: ["text", "image"] }),
),
})),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
Expand All @@ -59,8 +65,8 @@ export async function listGenericModels(
return pipe(
fetchJson(`${baseUrl}/models`, { Authorization: `Bearer ${apiKey}` }),
Effect.andThen((json) => Schema.decodeUnknown(OpenAIModelSchema)(json)),
Effect.map(({ data }) =>
partition(
Effect.map(({ data }) => ({
...partition(
data,
(model) => {
const reasons: ModelIgnoreReason[] = [];
Expand All @@ -71,7 +77,12 @@ export async function listGenericModels(
},
(model) => model.id,
),
),
metadata: extractMetadataMap(
data,
(model) => model.id,
() => ({ input_modalities: ["text"] }),
),
})),
Effect.timeout(REQUEST_TIMEOUT),
Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
Effect.runPromise,
Expand Down
Loading
Loading