diff --git a/packages/inference/src/errors.ts b/packages/inference/src/errors.ts
index 01c23d976e..bcc0a773ce 100644
--- a/packages/inference/src/errors.ts
+++ b/packages/inference/src/errors.ts
@@ -17,6 +17,13 @@ export class InferenceClientInputError extends InferenceClientError {
 	}
 }
 
+export class InferenceClientRoutingError extends InferenceClientError {
+	constructor(message: string) {
+		super(message);
+		this.name = "RoutingError";
+	}
+}
+
 interface HttpRequest {
 	url: string;
 	method: string;
diff --git a/packages/inference/src/lib/getInferenceProviderMapping.ts b/packages/inference/src/lib/getInferenceProviderMapping.ts
index e4cb4e90df..f6e0dd5656 100644
--- a/packages/inference/src/lib/getInferenceProviderMapping.ts
+++ b/packages/inference/src/lib/getInferenceProviderMapping.ts
@@ -124,6 +124,17 @@ export async function getInferenceProviderMapping(
 	}
 ): Promise<InferenceProviderMappingEntry | null> {
 	const logger = getLogger();
+	if (params.provider === ("auto" as InferenceProvider) && params.task === "conversational") {
+		// Special case for auto + conversational to avoid extra API calls
+		// Call directly the server-side auto router
+		return {
+			hfModelId: params.modelId,
+			provider: "auto",
+			providerId: params.modelId,
+			status: "live",
+			task: "conversational",
+		};
+	}
 	if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
 		return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
 	}
diff --git a/packages/inference/src/providers/providerHelper.ts b/packages/inference/src/providers/providerHelper.ts
index fc1ebc25f8..c92d106317 100644
--- a/packages/inference/src/providers/providerHelper.ts
+++ b/packages/inference/src/providers/providerHelper.ts
@@ -47,7 +47,7 @@ import type {
 	ZeroShotImageClassificationOutput,
 } from "@huggingface/tasks";
 import { HF_ROUTER_URL } from "../config.js";
-import { InferenceClientProviderOutputError } from "../errors.js";
+import { InferenceClientProviderOutputError, InferenceClientRoutingError } from "../errors.js";
 import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio.js";
 import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types.js";
 import { toArray } from "../utils/toArray.js";
@@ -62,7 +62,7 @@ import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
 
 export abstract class TaskProviderHelper {
 	constructor(
 		readonly provider: InferenceProvider,
-		private baseUrl: string,
+		protected baseUrl: string,
 		readonly clientSideRoutingOnly: boolean = false
 	) {}
@@ -369,3 +369,16 @@ export class BaseTextGenerationTask extends TaskProviderHelper implements TextGenerationTaskHelper {
 		throw new InferenceClientProviderOutputError("Expected Array<{generated_text: string}>");
 	}
 }
+
+export class AutoRouterConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("auto" as InferenceProvider, "https://router.huggingface.co");
+	}
+
+	override makeBaseUrl(params: UrlParams): string {
+		if (params.authMethod !== "hf-token") {
+			throw new InferenceClientRoutingError("Cannot select auto-router when using non-Hugging Face API key.");
+		}
+		return this.baseUrl;
+	}
+}
diff --git a/packages/inference/src/tasks/nlp/chatCompletion.ts b/packages/inference/src/tasks/nlp/chatCompletion.ts
index 0bfad236c7..ca3cefb0cf 100644
--- a/packages/inference/src/tasks/nlp/chatCompletion.ts
+++ b/packages/inference/src/tasks/nlp/chatCompletion.ts
@@ -3,6 +3,8 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
 import { getProviderHelper } from "../../lib/getProviderHelper.js";
 import type { BaseArgs, Options } from "../../types.js";
 import { innerRequest } from "../../utils/request.js";
+import type { ConversationalTaskHelper, TaskProviderHelper } from "../../providers/providerHelper.js";
+import { AutoRouterConversationalTask } from "../../providers/providerHelper.js";
 
 /**
  * Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
@@ -11,8 +13,14 @@ export async function chatCompletion(
 	args: BaseArgs & ChatCompletionInput,
 	options?: Options
 ): Promise<ChatCompletionOutput> {
-	const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
-	const providerHelper = getProviderHelper(provider, "conversational");
+	let providerHelper: ConversationalTaskHelper & TaskProviderHelper;
+	if (!args.provider || args.provider === "auto") {
+		// Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
+		providerHelper = new AutoRouterConversationalTask();
+	} else {
+		const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
+		providerHelper = getProviderHelper(provider, "conversational");
+	}
 	const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
 		...options,
 		task: "conversational",