Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/inference/src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ export class InferenceClientInputError extends InferenceClientError {
}
}

/**
 * Error raised when a request cannot be routed to an inference provider,
 * e.g. when the selected routing mode is incompatible with the credentials in use.
 */
export class InferenceClientRoutingError extends InferenceClientError {
constructor(message: string) {
super(message);
// NOTE(review): short name "RoutingError" rather than the full class name —
// presumably matches the naming scheme of the sibling error classes in this
// file (their `name` assignments are outside this view); confirm.
this.name = "RoutingError";
}
}

interface HttpRequest {
url: string;
method: string;
Expand Down
11 changes: 11 additions & 0 deletions packages/inference/src/lib/getInferenceProviderMapping.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,17 @@ export async function getInferenceProviderMapping(
}
): Promise<InferenceProviderMappingEntry | null> {
const logger = getLogger();
if (params.provider === ("auto" as InferenceProvider) && params.task === "conversational") {
// Special case for auto + conversational to avoid extra API calls:
// call the server-side auto router directly.
return {
hfModelId: params.modelId,
provider: "auto",
providerId: params.modelId,
status: "live",
task: "conversational",
};
}
if (HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId]) {
return HARDCODED_MODEL_INFERENCE_MAPPING[params.provider][params.modelId];
}
Expand Down
17 changes: 15 additions & 2 deletions packages/inference/src/providers/providerHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import type {
ZeroShotImageClassificationOutput,
} from "@huggingface/tasks";
import { HF_ROUTER_URL } from "../config.js";
import { InferenceClientProviderOutputError } from "../errors.js";
import { InferenceClientProviderOutputError, InferenceClientRoutingError } from "../errors.js";
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio.js";
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types.js";
import { toArray } from "../utils/toArray.js";
Expand All @@ -62,7 +62,7 @@ import type { ImageSegmentationArgs } from "../tasks/cv/imageSegmentation.js";
export abstract class TaskProviderHelper {
constructor(
readonly provider: InferenceProvider,
private baseUrl: string,
protected baseUrl: string,
readonly clientSideRoutingOnly: boolean = false
) {}

Expand Down Expand Up @@ -369,3 +369,16 @@ export class BaseTextGenerationTask extends TaskProviderHelper implements TextGe
throw new InferenceClientProviderOutputError("Expected Array<{generated_text: string}>");
}
}

export class AutoRouterConversationalTask extends BaseConversationalTask {
constructor() {
super("auto" as InferenceProvider, "https://router.huggingface.co");
}

override makeBaseUrl(params: UrlParams): string {
if (params.authMethod !== "hf-token") {
throw new InferenceClientRoutingError("Cannot select auto-router when using non-Hugging Face API key.");
}
return this.baseUrl;
}
}
12 changes: 10 additions & 2 deletions packages/inference/src/tasks/nlp/chatCompletion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import { resolveProvider } from "../../lib/getInferenceProviderMapping.js";
import { getProviderHelper } from "../../lib/getProviderHelper.js";
import type { BaseArgs, Options } from "../../types.js";
import { innerRequest } from "../../utils/request.js";
import type { ConversationalTaskHelper, TaskProviderHelper } from "../../providers/providerHelper.js";
import { AutoRouterConversationalTask } from "../../providers/providerHelper.js";

/**
* Use the chat completion endpoint to generate a response to a prompt, using OpenAI message completion API no stream
Expand All @@ -11,8 +13,14 @@ export async function chatCompletion(
args: BaseArgs & ChatCompletionInput,
options?: Options
): Promise<ChatCompletionOutput> {
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
const providerHelper = getProviderHelper(provider, "conversational");
let providerHelper: ConversationalTaskHelper & TaskProviderHelper;
if (!args.provider || args.provider === "auto") {
// Special case: we have a dedicated auto-router for conversational models. No need to fetch provider mapping.
providerHelper = new AutoRouterConversationalTask();
} else {
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
providerHelper = getProviderHelper(provider, "conversational");
}
const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
...options,
task: "conversational",
Expand Down