11 changes: 6 additions & 5 deletions packages/inference/src/lib/makeRequestOptions.ts
@@ -1,7 +1,7 @@
import { name as packageName, version as packageVersion } from "../../package.json";
import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config";
import type { InferenceTask, Options, RequestArgs } from "../types";
import { getProviderHelper } from "./getProviderHelper";
import type { getProviderHelper } from "./getProviderHelper";
import { getProviderModelId } from "./getProviderModelId";
import { isUrl } from "./isUrl";

@@ -20,6 +20,7 @@ export async function makeRequestOptions(
data?: Blob | ArrayBuffer;
stream?: boolean;
},
providerHelper: ReturnType<typeof getProviderHelper>,
options?: Options & {
/** In most cases (unless we pass a endpointUrl) we know the task */
task?: InferenceTask;
@@ -28,6 +29,7 @@
const { provider: maybeProvider, model: maybeModel } = args;
const provider = maybeProvider ?? "hf-inference";
const { task } = options ?? {};

// Validate inputs
if (args.endpointUrl && provider !== "hf-inference") {
throw new Error(`Cannot use endpointUrl with a third-party provider.`);
@@ -38,7 +40,7 @@

if (args.endpointUrl) {
// No need to have maybeModel, or to load default model for a task
return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, args, options);
return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, options);
}

if (!maybeModel && !task) {
@@ -47,7 +49,6 @@

// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const hfModel = maybeModel ?? (await loadDefaultModel(task!));
const providerHelper = getProviderHelper(provider, task);

if (providerHelper.clientSideRoutingOnly && !maybeModel) {
throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
@@ -62,7 +63,7 @@
});

// Use the sync version with the resolved model
return makeRequestOptionsFromResolvedModel(resolvedModel, args, options);
return makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, options);
}

/**
@@ -71,6 +72,7 @@
*/
export function makeRequestOptionsFromResolvedModel(
resolvedModel: string,
providerHelper: ReturnType<typeof getProviderHelper>,
args: RequestArgs & {
data?: Blob | ArrayBuffer;
stream?: boolean;
@@ -85,7 +87,6 @@ export function makeRequestOptionsFromResolvedModel(
const provider = maybeProvider ?? "hf-inference";

const { includeCredentials, task, signal, billTo } = options ?? {};
const providerHelper = getProviderHelper(provider, task);
const authMethod = (() => {
if (providerHelper.clientSideRoutingOnly) {
// Closed-source providers require an accessToken (cannot be routed).
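For reference, a minimal sketch (not part of this diff) of the new calling convention: the caller resolves the helper once via getProviderHelper and passes it explicitly, instead of makeRequestOptions looking it up internally. Import paths, model id and token below are illustrative.

// Sketch only; paths, model id and token are illustrative.
import { getProviderHelper } from "./getProviderHelper";
import { makeRequestOptions } from "./makeRequestOptions";

async function buildRequest() {
  // Resolve the helper once at the call site...
  const providerHelper = getProviderHelper("hf-inference", "text-classification");
  // ...and thread it through instead of letting makeRequestOptions re-create it.
  const { url, info } = await makeRequestOptions(
    { accessToken: "hf_xxx", model: "distilbert-base-uncased-finetuned-sst-2-english", inputs: "Great library!" },
    providerHelper,
    { task: "text-classification" }
  );
  return { url, info };
}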
12 changes: 6 additions & 6 deletions packages/inference/src/providers/hf-inference.ts
@@ -385,13 +385,13 @@ export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements
typeof elem.end === "number" &&
typeof elem.score === "number" &&
typeof elem.start === "number"
)
)
: typeof response === "object" &&
!!response &&
typeof response.answer === "string" &&
typeof response.end === "number" &&
typeof response.score === "number" &&
typeof response.start === "number"
!!response &&
typeof response.answer === "string" &&
typeof response.end === "number" &&
typeof response.score === "number" &&
typeof response.start === "number"
) {
return Array.isArray(response) ? response[0] : response;
}
9 changes: 9 additions & 0 deletions packages/inference/src/snippets/getInferenceSnippets.ts
@@ -11,6 +11,7 @@ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingf
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
import { templates } from "./templates.exported";
import { getProviderHelper } from "../lib/getProviderHelper";

const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const;
const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const;
@@ -130,10 +131,18 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
inputPreparationFn = prepareConversationalInput;
task = "conversational";
}
let providerHelper: ReturnType<typeof getProviderHelper>;
try {
providerHelper = getProviderHelper(provider, task);
} catch (e) {
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
return [];
}
/// Prepare inputs + make request
const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) };
const request = makeRequestOptionsFromResolvedModel(
providerModelId ?? model.id,
providerHelper,
{
accessToken: accessToken,
provider: provider,
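A note on the fallback above: when no helper is registered for a provider/task pair, snippet generation now degrades to an empty list instead of throwing while a model page is rendered. A minimal sketch of the same pattern, extracted as a helper (illustrative only, not part of this diff):

import { getProviderHelper } from "../lib/getProviderHelper";
import type { InferenceProvider, InferenceTask } from "../types";

function tryGetProviderHelper(provider: InferenceProvider, task?: InferenceTask) {
  try {
    return getProviderHelper(provider, task);
  } catch (e) {
    console.error(`Failed to get provider helper for ${provider} (${task})`, e);
    return undefined; // the snippet generator returns [] in this case
  }
}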
packages/inference/src/tasks/audio/audioClassification.ts
@@ -17,7 +17,7 @@ export async function audioClassification(
): Promise<AudioClassificationOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification");
const payload = preparePayload(args);
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, {
const { data: res } = await innerRequest<AudioClassificationOutput>(payload, providerHelper, {
...options,
task: "audio-classification",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/audio/audioToAudio.ts
@@ -38,7 +38,7 @@ export interface AudioToAudioOutput {
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio");
const payload = preparePayload(args);
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, {
const { data: res } = await innerRequest<AudioToAudioOutput>(payload, providerHelper, {
...options,
task: "audio-to-audio",
});
packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -20,7 +20,7 @@ export async function automaticSpeechRecognition(
): Promise<AutomaticSpeechRecognitionOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition");
const payload = await buildPayload(args);
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, {
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
...options,
task: "automatic-speech-recognition",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/audio/textToSpeech.ts
@@ -14,7 +14,7 @@ interface OutputUrlTextToSpeechGeneration {
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
const provider = args.provider ?? "hf-inference";
const providerHelper = getProviderHelper(provider, "text-to-speech");
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, {
const { data: res } = await innerRequest<Blob | OutputUrlTextToSpeechGeneration>(args, providerHelper, {
...options,
task: "text-to-speech",
});
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/custom/request.ts
@@ -1,3 +1,4 @@
import { getProviderHelper } from "../../lib/getProviderHelper";
import type { InferenceTask, Options, RequestArgs } from "../../types";
import { innerRequest } from "../../utils/request";

@@ -15,6 +16,7 @@ export async function request<T>(
console.warn(
"The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
);
const result = await innerRequest<T>(args, options);
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
const result = await innerRequest<T>(args, providerHelper, options);
return result.data;
}
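For context, the public surface of the deprecated request() is unchanged; only the internal wiring now resolves the helper from args.provider and options.task. A usage sketch, assuming the package's existing export (model id and token illustrative):

import { request } from "@huggingface/inference";

const generated = await request<Array<{ generated_text: string }>>(
  { accessToken: "hf_xxx", model: "gpt2", inputs: "Hello world" },
  { task: "text-generation" }
);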
5 changes: 4 additions & 1 deletion packages/inference/src/tasks/custom/streamingRequest.ts
@@ -1,5 +1,7 @@
import { getProviderHelper } from "../../lib/getProviderHelper";
import type { InferenceTask, Options, RequestArgs } from "../../types";
import { innerStreamingRequest } from "../../utils/request";

/**
* Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
* @deprecated Use specific task functions instead. This function will be removed in a future version.
@@ -14,5 +16,6 @@ export async function* streamingRequest<T>(
console.warn(
"The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead."
);
yield* innerStreamingRequest(args, options);
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task);
yield* innerStreamingRequest(args, providerHelper, options);
}
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageClassification.ts
@@ -16,7 +16,7 @@ export async function imageClassification(
): Promise<ImageClassificationOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification");
const payload = preparePayload(args);
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, {
const { data: res } = await innerRequest<ImageClassificationOutput>(payload, providerHelper, {
...options,
task: "image-classification",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageSegmentation.ts
@@ -16,7 +16,7 @@ export async function imageSegmentation(
): Promise<ImageSegmentationOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation");
const payload = preparePayload(args);
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, {
const { data: res } = await innerRequest<ImageSegmentationOutput>(payload, providerHelper, {
...options,
task: "image-segmentation",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageToImage.ts
@@ -27,7 +27,7 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P
),
};
}
const { data: res } = await innerRequest<Blob>(reqArgs, {
const { data: res } = await innerRequest<Blob>(reqArgs, providerHelper, {
...options,
task: "image-to-image",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/imageToText.ts
@@ -12,7 +12,7 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput);
export async function imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text");
const payload = preparePayload(args);
const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, {
const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, {
...options,
task: "image-to-text",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/cv/objectDetection.ts
@@ -13,7 +13,7 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage
export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection");
const payload = preparePayload(args);
const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, {
const { data: res } = await innerRequest<ObjectDetectionOutput>(payload, providerHelper, {
...options,
task: "object-detection",
});
4 changes: 2 additions & 2 deletions packages/inference/src/tasks/cv/textToImage.ts
@@ -25,11 +25,11 @@ export async function textToImage(
export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise<Blob | string> {
const provider = args.provider ?? "hf-inference";
const providerHelper = getProviderHelper(provider, "text-to-image");
const { data: res } = await innerRequest<Record<string, unknown>>(args, {
const { data: res } = await innerRequest<Record<string, unknown>>(args, providerHelper, {
...options,
task: "text-to-image",
});

const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" });
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-image" });
return providerHelper.getResponse(res, url, info.headers as Record<string, string>, options?.outputType);
}
14 changes: 9 additions & 5 deletions packages/inference/src/tasks/cv/textToVideo.ts
@@ -14,10 +14,14 @@ export type TextToVideoOutput = Blob;
export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise<TextToVideoOutput> {
const provider = args.provider ?? "hf-inference";
const providerHelper = getProviderHelper(provider, "text-to-video");
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(args, {
...options,
task: "text-to-video",
});
const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" });
const { data: response } = await innerRequest<FalAiQueueOutput | ReplicateOutput | NovitaOutput>(
args,
providerHelper,
{
...options,
task: "text-to-video",
}
);
const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-video" });
return providerHelper.getResponse(response, url, info.headers as Record<string, string>);
}
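Callers of textToVideo are unaffected; the same providerHelper instance now builds both the request options and the response handling. A usage sketch (provider, model id and token illustrative, not from this diff):

import { textToVideo } from "@huggingface/inference";

const video: Blob = await textToVideo({
  accessToken: "hf_xxx",
  provider: "fal-ai",
  model: "tencent/HunyuanVideo",
  inputs: "A red panda surfing a small wave at sunset",
});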
packages/inference/src/tasks/cv/zeroShotImageClassification.ts
@@ -46,7 +46,7 @@ export async function zeroShotImageClassification(
): Promise<ZeroShotImageClassificationOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification");
const payload = await preparePayload(args);
const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, {
const { data: res } = await innerRequest<ZeroShotImageClassificationOutput>(payload, providerHelper, {
...options,
task: "zero-shot-image-classification",
});
packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts
@@ -30,6 +30,7 @@ export async function documentQuestionAnswering(
} as RequestArgs;
const { data: res } = await innerRequest<DocumentQuestionAnsweringOutput | DocumentQuestionAnsweringOutput[number]>(
reqArgs,
providerHelper,
{
...options,
task: "document-question-answering",
packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts
@@ -29,7 +29,7 @@ export async function visualQuestionAnswering(
},
} as RequestArgs;

const { data: res } = await innerRequest<VisualQuestionAnsweringOutput>(reqArgs, {
const { data: res } = await innerRequest<VisualQuestionAnsweringOutput>(reqArgs, providerHelper, {
...options,
task: "visual-question-answering",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/chatCompletion.ts
@@ -11,7 +11,7 @@ export async function chatCompletion(
options?: Options
): Promise<ChatCompletionOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
const { data: response } = await innerRequest<ChatCompletionOutput>(args, {
const { data: response } = await innerRequest<ChatCompletionOutput>(args, providerHelper, {
...options,
task: "conversational",
});
4 changes: 3 additions & 1 deletion packages/inference/src/tasks/nlp/chatCompletionStream.ts
@@ -1,4 +1,5 @@
import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks";
import { getProviderHelper } from "../../lib/getProviderHelper";
import type { BaseArgs, Options } from "../../types";
import { innerStreamingRequest } from "../../utils/request";

@@ -9,7 +10,8 @@ export async function* chatCompletionStream(
args: BaseArgs & ChatCompletionInput,
options?: Options
): AsyncGenerator<ChatCompletionStreamOutput> {
yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational");
yield* innerStreamingRequest<ChatCompletionStreamOutput>(args, providerHelper, {
...options,
task: "conversational",
});
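Public streaming usage is likewise unchanged by this refactor. A consumption sketch (model id and token illustrative):

import { chatCompletionStream } from "@huggingface/inference";

for await (const chunk of chatCompletionStream({
  accessToken: "hf_xxx",
  model: "meta-llama/Llama-3.1-8B-Instruct",
  messages: [{ role: "user", content: "Hello!" }],
})) {
  process.stdout.write(chunk.choices[0]?.delta?.content ?? "");
}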
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/featureExtraction.ts
@@ -18,7 +18,7 @@ export async function featureExtraction(
options?: Options
): Promise<FeatureExtractionOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction");
const { data: res } = await innerRequest<FeatureExtractionOutput>(args, {
const { data: res } = await innerRequest<FeatureExtractionOutput>(args, providerHelper, {
...options,
task: "feature-extraction",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/fillMask.ts
@@ -10,7 +10,7 @@ export type FillMaskArgs = BaseArgs & FillMaskInput;
*/
export async function fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask");
const { data: res } = await innerRequest<FillMaskOutput>(args, {
const { data: res } = await innerRequest<FillMaskOutput>(args, providerHelper, {
...options,
task: "fill-mask",
});
12 changes: 8 additions & 4 deletions packages/inference/src/tasks/nlp/questionAnswering.ts
@@ -13,9 +13,13 @@ export async function questionAnswering(
options?: Options
): Promise<QuestionAnsweringOutput[number]> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering");
const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(args, {
...options,
task: "question-answering",
});
const { data: res } = await innerRequest<QuestionAnsweringOutput | QuestionAnsweringOutput[number]>(
args,
providerHelper,
{
...options,
task: "question-answering",
}
);
return providerHelper.getResponse(res);
}
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/sentenceSimilarity.ts
@@ -13,7 +13,7 @@ export async function sentenceSimilarity(
options?: Options
): Promise<SentenceSimilarityOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity");
const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, {
const { data: res } = await innerRequest<SentenceSimilarityOutput>(args, providerHelper, {
...options,
task: "sentence-similarity",
});
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/summarization.ts
@@ -10,7 +10,7 @@ export type SummarizationArgs = BaseArgs & SummarizationInput;
*/
export async function summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization");
const { data: res } = await innerRequest<SummarizationOutput[]>(args, {
const { data: res } = await innerRequest<SummarizationOutput[]>(args, providerHelper, {
...options,
task: "summarization",
});
12 changes: 8 additions & 4 deletions packages/inference/src/tasks/nlp/tableQuestionAnswering.ts
@@ -13,9 +13,13 @@ export async function tableQuestionAnswering(
options?: Options
): Promise<TableQuestionAnsweringOutput[number]> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering");
const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(args, {
...options,
task: "table-question-answering",
});
const { data: res } = await innerRequest<TableQuestionAnsweringOutput | TableQuestionAnsweringOutput[number]>(
args,
providerHelper,
{
...options,
task: "table-question-answering",
}
);
return providerHelper.getResponse(res);
}
2 changes: 1 addition & 1 deletion packages/inference/src/tasks/nlp/textClassification.ts
@@ -13,7 +13,7 @@ export async function textClassification(
options?: Options
): Promise<TextClassificationOutput> {
const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification");
const { data: res } = await innerRequest<TextClassificationOutput>(args, {
const { data: res } = await innerRequest<TextClassificationOutput>(args, providerHelper, {
...options,
task: "text-classification",
});