Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions apps/desktop/src/components/settings/ai/llm/configure.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,13 @@ function ProviderContext({ providerId }: { providerId: ProviderId }) {
? "We only support **OpenAI-compatible** endpoints for now."
: providerId === "openrouter"
? "We filter out models from the combobox based on heuristics like **input modalities** and **tool support**."
: providerId === "google_generative_ai"
? "Visit [AI Studio](https://aistudio.google.com/api-keys) to create an API key."
: "";
: providerId === "azure_openai"
? "Enter your **Azure OpenAI endpoint** (e.g. `https://your-resource.openai.azure.com`) as the Base URL and your **API key**. [Report issues](https://github.com/fastrepl/char/issues/3928)"
: providerId === "azure_ai"
? "Enter your **Azure AI Foundry endpoint** as the Base URL and your **API key**. Supports Claude and other models deployed via Azure AI Foundry. [Report issues](https://github.com/fastrepl/char/issues/3928)"
: providerId === "google_generative_ai"
? "Visit [AI Studio](https://aistudio.google.com/api-keys) to create an API key."
: "";

if (!content) {
return null;
Expand Down
8 changes: 8 additions & 0 deletions apps/desktop/src/components/settings/ai/llm/select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import {
requiresEntitlement,
} from "../shared/eligibility";
import { listAnthropicModels } from "../shared/list-anthropic";
import { listAzureAIModels } from "../shared/list-azure-ai";
import { listAzureOpenAIModels } from "../shared/list-azure-openai";
import {
type InputModality,
type ListModelsResult,
Expand Down Expand Up @@ -269,6 +271,12 @@ function useConfiguredMapping(): Record<string, ProviderStatus> {
case "mistral":
listModelsFunc = () => listMistralModels(baseUrl, apiKey);
break;
case "azure_openai":
listModelsFunc = () => listAzureOpenAIModels(baseUrl, apiKey);
break;
case "azure_ai":
listModelsFunc = () => listAzureAIModels(baseUrl, apiKey);
break;
case "ollama":
listModelsFunc = () => listOllamaModels(baseUrl, apiKey);
break;
Expand Down
22 changes: 22 additions & 0 deletions apps/desktop/src/components/settings/ai/llm/shared.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import { Icon } from "@iconify-icon/react";
import {
Anthropic,
Azure,
AzureAI,
LmStudio,
Mistral,
Ollama,
Expand Down Expand Up @@ -110,6 +112,26 @@ const _PROVIDERS = [
baseUrl: "https://api.mistral.ai/v1",
requirements: [{ kind: "requires_config", fields: ["api_key"] }],
},
{
id: "azure_openai",
displayName: "Azure OpenAI",
badge: "Beta",
icon: <Azure size={16} />,
baseUrl: undefined,
requirements: [
{ kind: "requires_config", fields: ["base_url", "api_key"] },
],
},
{
id: "azure_ai",
displayName: "Azure AI Foundry",
badge: "Beta",
icon: <AzureAI size={16} />,
baseUrl: undefined,
requirements: [
{ kind: "requires_config", fields: ["base_url", "api_key"] },
],
},
{
id: "google_generative_ai",
displayName: "Google Gemini",
Expand Down
60 changes: 60 additions & 0 deletions apps/desktop/src/components/settings/ai/shared/list-azure-ai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { Effect, pipe, Schema } from "effect";

import {
DEFAULT_RESULT,
extractMetadataMap,
fetchJson,
type ListModelsResult,
type ModelIgnoreReason,
partition,
REQUEST_TIMEOUT,
shouldIgnoreCommonKeywords,
} from "./list-common";

// Shape of the JSON returned by the Azure AI Foundry `/models` endpoint.
// Only `id` is required; `model` is decoded when present but is not read
// by listAzureAIModels below. Extra fields in the response are ignored
// by the struct decode.
const AzureAIDeploymentSchema = Schema.Struct({
  data: Schema.Array(
    Schema.Struct({
      id: Schema.String,
      model: Schema.optional(Schema.String),
    }),
  ),
});

/**
 * List models deployed behind an Azure AI Foundry endpoint.
 *
 * @param baseUrl - The Foundry endpoint; trailing slashes are stripped
 *   before `/models` is appended. An empty string short-circuits to
 *   {@link DEFAULT_RESULT}.
 * @param apiKey - Sent as the `api-key` request header.
 * @returns Partitioned model ids plus a metadata map; any fetch, decode,
 *   or timeout failure resolves to {@link DEFAULT_RESULT} rather than
 *   rejecting.
 */
export async function listAzureAIModels(
  baseUrl: string,
  apiKey: string,
): Promise<ListModelsResult> {
  // Without an endpoint there is nothing to query.
  if (!baseUrl) {
    return DEFAULT_RESULT;
  }

  const endpoint = `${baseUrl.replace(/\/+$/, "")}/models`;

  // Foundry deployments are only filtered on common non-chat keywords.
  const ignoreReasonsFor = (modelId: string): ModelIgnoreReason[] | null =>
    shouldIgnoreCommonKeywords(modelId) ? ["common_keyword"] : null;

  const program = pipe(
    fetchJson(endpoint, { "api-key": apiKey }),
    Effect.andThen(Schema.decodeUnknown(AzureAIDeploymentSchema)),
    Effect.map(({ data }) => {
      const partitioned = partition(
        data,
        (model) => ignoreReasonsFor(model.id),
        (model) => model.id,
      );
      // NOTE(review): modalities are hard-coded; this assumes every
      // Foundry deployment accepts text + image input — confirm per model.
      const metadata = extractMetadataMap(
        data,
        (model) => model.id,
        () => ({ input_modalities: ["text", "image"] }),
      );
      return { ...partitioned, metadata };
    }),
    Effect.timeout(REQUEST_TIMEOUT),
    // Swallow every failure into the empty result so callers never reject.
    Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
  );

  return Effect.runPromise(program);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import { Effect, pipe, Schema } from "effect";

import {
DEFAULT_RESULT,
extractMetadataMap,
fetchJson,
isDateSnapshot,
isOldModel,
type ListModelsResult,
type ModelIgnoreReason,
partition,
REQUEST_TIMEOUT,
shouldIgnoreCommonKeywords,
} from "./list-common";

// Shape of the JSON returned by the Azure OpenAI
// `/openai/models?api-version=...` endpoint. `capabilities` is optional and
// every flag inside it is optional too; listAzureOpenAIModels only reads
// `chat_completion` and `completion` (both must be explicitly `false` for a
// model to be ignored as not-a-chat-model). `embeddings` and `inference` are
// decoded but currently unused.
const AzureOpenAIModelSchema = Schema.Struct({
  data: Schema.Array(
    Schema.Struct({
      id: Schema.String,
      capabilities: Schema.optional(
        Schema.Struct({
          chat_completion: Schema.optional(Schema.Boolean),
          completion: Schema.optional(Schema.Boolean),
          embeddings: Schema.optional(Schema.Boolean),
          inference: Schema.optional(Schema.Boolean),
        }),
      ),
    }),
  ),
});

/**
 * List models available on an Azure OpenAI resource.
 *
 * @param baseUrl - The resource endpoint; trailing slashes are stripped
 *   before the `/openai/models` path (with a pinned `api-version`) is
 *   appended. An empty string short-circuits to {@link DEFAULT_RESULT}.
 * @param apiKey - Sent as the `api-key` request header.
 * @returns Partitioned model ids plus a metadata map; any fetch, decode,
 *   or timeout failure resolves to {@link DEFAULT_RESULT} rather than
 *   rejecting.
 */
export async function listAzureOpenAIModels(
  baseUrl: string,
  apiKey: string,
): Promise<ListModelsResult> {
  // No endpoint configured — nothing to list.
  if (!baseUrl) {
    return DEFAULT_RESULT;
  }

  const endpoint = `${baseUrl.replace(/\/+$/, "")}/openai/models?api-version=2024-10-21`;

  const program = pipe(
    fetchJson(endpoint, { "api-key": apiKey }),
    Effect.andThen(Schema.decodeUnknown(AzureOpenAIModelSchema)),
    Effect.map(({ data }) => {
      const partitioned = partition(
        data,
        (model) => {
          const reasons: ModelIgnoreReason[] = [];
          if (shouldIgnoreCommonKeywords(model.id)) {
            reasons.push("common_keyword");
          }
          if (isOldModel(model.id)) {
            reasons.push("old_model");
          }
          if (isDateSnapshot(model.id)) {
            reasons.push("date_snapshot");
          }
          // Only ignore when the API explicitly reports the model can do
          // neither chat nor legacy completions; absent flags keep it.
          const caps = model.capabilities;
          if (
            caps &&
            caps.chat_completion === false &&
            caps.completion === false
          ) {
            reasons.push("not_chat_model");
          }
          return reasons.length > 0 ? reasons : null;
        },
        (model) => model.id,
      );
      // NOTE(review): modalities are hard-coded; assumes every deployment
      // accepts text + image input — confirm per model.
      const metadata = extractMetadataMap(
        data,
        (model) => model.id,
        () => ({ input_modalities: ["text", "image"] }),
      );
      return { ...partitioned, metadata };
    }),
    Effect.timeout(REQUEST_TIMEOUT),
    // Swallow every failure into the empty result so callers never reject.
    Effect.catchAll(() => Effect.succeed(DEFAULT_RESULT)),
  );

  return Effect.runPromise(program);
}
21 changes: 21 additions & 0 deletions apps/desktop/src/hooks/useLLMConnection.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { createAnthropic } from "@ai-sdk/anthropic";
import { createAzure } from "@ai-sdk/azure";
import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { createOpenAI } from "@ai-sdk/openai";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
Expand Down Expand Up @@ -283,6 +284,26 @@ const createLanguageModel = (
return wrapWithThinkingMiddleware(provider(conn.modelId));
}

case "azure_openai": {
const provider = createAzure({
fetch: tauriFetch,
baseURL: conn.baseUrl,
apiKey: conn.apiKey,
});
return wrapWithThinkingMiddleware(provider(conn.modelId));
}

case "azure_ai": {
const provider = createOpenAICompatible({
fetch: tauriFetch,
name: "azure_ai",
baseURL: conn.baseUrl,
apiKey: conn.apiKey,
headers: { "api-key": conn.apiKey },
});
return wrapWithThinkingMiddleware(provider.chatModel(conn.modelId));
}

case "ollama": {
const ollamaOrigin = new URL(conn.baseUrl.replace(/\/v1\/?$/, "")).origin;
const ollamaFetch: typeof fetch = async (input, init) => {
Expand Down