diff --git a/apps/desktop/src/components/settings/ai/llm/configure.tsx b/apps/desktop/src/components/settings/ai/llm/configure.tsx index d591dae9cd..1c186d3fa0 100644 --- a/apps/desktop/src/components/settings/ai/llm/configure.tsx +++ b/apps/desktop/src/components/settings/ai/llm/configure.tsx @@ -201,7 +201,9 @@ function ProviderContext({ providerId }: { providerId: ProviderId }) { ? "We only support **OpenAI-compatible** endpoints for now." : providerId === "openrouter" ? "We filter out models from the combobox based on heuristics like **input modalities** and **tool support**." - : ""; + : providerId === "google_generative_ai" + ? "Visit [AI Studio](https://aistudio.google.com/api-keys) to create an API key." + : ""; if (providerId === "hyprnote" && !isPro) { return ( diff --git a/apps/desktop/src/components/settings/ai/llm/select.tsx b/apps/desktop/src/components/settings/ai/llm/select.tsx index a36a24eaae..c5db712928 100644 --- a/apps/desktop/src/components/settings/ai/llm/select.tsx +++ b/apps/desktop/src/components/settings/ai/llm/select.tsx @@ -15,6 +15,7 @@ import { useBillingAccess } from "../../../../billing"; import { useConfigValues } from "../../../../config/use-config"; import * as main from "../../../../store/tinybase/main"; import type { ListModelsResult } from "../shared/list-common"; +import { listGoogleModels } from "../shared/list-google"; import { listLMStudioModels } from "../shared/list-lmstudio"; import { listOllamaModels } from "../shared/list-ollama"; import { @@ -233,6 +234,9 @@ function useConfiguredMapping(): Record< case "openrouter": listModelsFunc = () => listOpenRouterModels(baseUrl, apiKey); break; + case "google_generative_ai": + listModelsFunc = () => listGoogleModels(baseUrl, apiKey); + break; case "ollama": listModelsFunc = () => listOllamaModels(baseUrl, apiKey); break; diff --git a/apps/desktop/src/components/settings/ai/llm/shared.tsx b/apps/desktop/src/components/settings/ai/llm/shared.tsx index 477cee7f56..0a17fceb7f 100644 --- a/apps/desktop/src/components/settings/ai/llm/shared.tsx +++ b/apps/desktop/src/components/settings/ai/llm/shared.tsx @@ -57,6 +57,15 @@ export const PROVIDERS = [ baseUrl: "https://api.anthropic.com/v1", requiresPro: false, }, + { + id: "google_generative_ai", + displayName: "Google Gemini", + badge: null, + icon: , + apiKey: true, + baseUrl: "https://generativelanguage.googleapis.com/v1beta", + requiresPro: false, + }, { id: "custom", displayName: "Custom", diff --git a/apps/desktop/src/components/settings/ai/shared/list-common.ts b/apps/desktop/src/components/settings/ai/shared/list-common.ts index 77377684cc..cd16714c3e 100644 --- a/apps/desktop/src/components/settings/ai/shared/list-common.ts +++ b/apps/desktop/src/components/settings/ai/shared/list-common.ts @@ -26,6 +26,8 @@ export const commonIgnoreKeywords = [ "dall-e", "audio", "image", + "computer", + "robotics", ] as const; export const fetchJson = (url: string, headers: Record) => diff --git a/apps/desktop/src/components/settings/ai/shared/list-google.ts b/apps/desktop/src/components/settings/ai/shared/list-google.ts new file mode 100644 index 0000000000..c2f133246b --- /dev/null +++ b/apps/desktop/src/components/settings/ai/shared/list-google.ts @@ -0,0 +1,63 @@ +import { Effect, pipe, Schema } from "effect"; + +import { + DEFAULT_RESULT, + fetchJson, + type ListModelsResult, + type ModelIgnoreReason, + partition, + REQUEST_TIMEOUT, + shouldIgnoreCommonKeywords, +} from "./list-common"; + +const GoogleModelSchema = Schema.Struct({ + models: Schema.Array( + Schema.Struct({ + name: Schema.String, + supportedGenerationMethods: Schema.optional(Schema.Array(Schema.String)), + }), + ), +}); + +type GoogleModel = Schema.Schema.Type< + typeof GoogleModelSchema +>["models"][number]; + +export async function listGoogleModels( + baseUrl: string, + apiKey: string, +): Promise { + if (!baseUrl) { + return DEFAULT_RESULT; + } + + const supportsGeneration = (model: GoogleModel): boolean => + !model.supportedGenerationMethods || + model.supportedGenerationMethods.includes("generateContent"); + + const getIgnoreReasons = (model: GoogleModel): ModelIgnoreReason[] | null => { + const reasons: ModelIgnoreReason[] = []; + if (shouldIgnoreCommonKeywords(model.name)) { + reasons.push("common_keyword"); + } + if (!supportsGeneration(model)) { + reasons.push("no_completion"); + } + return reasons.length > 0 ? reasons : null; + }; + + const extractModelId = (model: GoogleModel): string => { + return model.name.replace(/^models\//, ""); + }; + + return pipe( + fetchJson(`${baseUrl}/models`, { "x-goog-api-key": apiKey }), + Effect.andThen((json) => Schema.decodeUnknown(GoogleModelSchema)(json)), + Effect.map(({ models }) => + partition(models, getIgnoreReasons, extractModelId), + ), + Effect.timeout(REQUEST_TIMEOUT), + Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)), + Effect.runPromise, + ); +} diff --git a/apps/desktop/src/hooks/useLLMConnection.ts b/apps/desktop/src/hooks/useLLMConnection.ts index d2fe214258..c2b070ab5a 100644 --- a/apps/desktop/src/hooks/useLLMConnection.ts +++ b/apps/desktop/src/hooks/useLLMConnection.ts @@ -1,4 +1,5 @@ import { createAnthropic } from "@ai-sdk/anthropic"; +import { createGoogleGenerativeAI } from "@ai-sdk/google"; import { createOpenAI } from "@ai-sdk/openai"; import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import { createOpenRouter } from "@openrouter/ai-sdk-provider"; @@ -76,6 +77,16 @@ export const useLanguageModel = (): Exclude | null => { return wrapWithThinkingMiddleware(anthropicProvider(conn.modelId)); } + if (conn.providerId === "google_generative_ai") { + const googleProvider = createGoogleGenerativeAI({ + fetch: tauriFetch, + baseURL: conn.baseUrl, + apiKey: conn.apiKey, + }); + + return wrapWithThinkingMiddleware(googleProvider(conn.modelId)); + } + if (conn.providerId === "openrouter") { const openRouterProvider = createOpenRouter({ fetch: tauriFetch,