From f0d27d92a80f6090f37c6ac23a8093fe371b841f Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Fri, 11 Apr 2025 14:04:56 +0200 Subject: [PATCH 1/3] top-down injection of providerHelper --- packages/inference/src/lib/makeRequestOptions.ts | 11 ++++++----- packages/inference/src/providers/hf-inference.ts | 12 ++++++------ .../inference/src/snippets/getInferenceSnippets.ts | 9 +++++++++ packages/inference/src/tasks/audio/audioToAudio.ts | 2 +- .../src/tasks/audio/automaticSpeechRecognition.ts | 2 +- packages/inference/src/tasks/audio/textToSpeech.ts | 2 +- packages/inference/src/tasks/custom/request.ts | 4 +++- .../inference/src/tasks/cv/imageClassification.ts | 2 +- .../inference/src/tasks/cv/imageSegmentation.ts | 2 +- packages/inference/src/tasks/cv/imageToImage.ts | 2 +- packages/inference/src/tasks/cv/imageToText.ts | 2 +- packages/inference/src/tasks/cv/objectDetection.ts | 2 +- packages/inference/src/tasks/cv/textToImage.ts | 4 ++-- packages/inference/src/tasks/cv/textToVideo.ts | 14 +++++++++----- .../src/tasks/cv/zeroShotImageClassification.ts | 2 +- .../tasks/multimodal/documentQuestionAnswering.ts | 1 + .../tasks/multimodal/visualQuestionAnswering.ts | 2 +- packages/inference/src/tasks/nlp/chatCompletion.ts | 2 +- .../inference/src/tasks/nlp/featureExtraction.ts | 2 +- packages/inference/src/tasks/nlp/fillMask.ts | 2 +- .../inference/src/tasks/nlp/questionAnswering.ts | 12 ++++++++---- .../inference/src/tasks/nlp/sentenceSimilarity.ts | 2 +- packages/inference/src/tasks/nlp/summarization.ts | 2 +- .../src/tasks/nlp/tableQuestionAnswering.ts | 12 ++++++++---- .../inference/src/tasks/nlp/textClassification.ts | 2 +- packages/inference/src/tasks/nlp/textGeneration.ts | 2 +- .../inference/src/tasks/nlp/tokenClassification.ts | 12 ++++++++---- packages/inference/src/tasks/nlp/translation.ts | 2 +- .../src/tasks/nlp/zeroShotClassification.ts | 12 ++++++++---- .../src/tasks/tabular/tabularClassification.ts | 2 +- .../src/tasks/tabular/tabularRegression.ts | 2 +- packages/inference/src/utils/request.ts | 11 +++++++---- 32 files changed, 95 insertions(+), 59 deletions(-) diff --git a/packages/inference/src/lib/makeRequestOptions.ts b/packages/inference/src/lib/makeRequestOptions.ts index 85bfd9263d..634a2a32fc 100644 --- a/packages/inference/src/lib/makeRequestOptions.ts +++ b/packages/inference/src/lib/makeRequestOptions.ts @@ -1,7 +1,7 @@ import { name as packageName, version as packageVersion } from "../../package.json"; import { HF_HEADER_X_BILL_TO, HF_HUB_URL } from "../config"; import type { InferenceTask, Options, RequestArgs } from "../types"; -import { getProviderHelper } from "./getProviderHelper"; +import type { getProviderHelper } from "./getProviderHelper"; import { getProviderModelId } from "./getProviderModelId"; import { isUrl } from "./isUrl"; @@ -20,6 +20,7 @@ export async function makeRequestOptions( data?: Blob | ArrayBuffer; stream?: boolean; }, + providerHelper: ReturnType, options?: Options & { /** In most cases (unless we pass a endpointUrl) we know the task */ task?: InferenceTask; @@ -28,6 +29,7 @@ export async function makeRequestOptions( const { provider: maybeProvider, model: maybeModel } = args; const provider = maybeProvider ?? "hf-inference"; const { task } = options ?? {}; + // Validate inputs if (args.endpointUrl && provider !== "hf-inference") { throw new Error(`Cannot use endpointUrl with a third-party provider.`); @@ -38,7 +40,7 @@ export async function makeRequestOptions( if (args.endpointUrl) { // No need to have maybeModel, or to load default model for a task - return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, args, options); + return makeRequestOptionsFromResolvedModel(maybeModel ?? args.endpointUrl, providerHelper, args, options); } if (!maybeModel && !task) { @@ -47,7 +49,6 @@ export async function makeRequestOptions( // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const hfModel = maybeModel ?? (await loadDefaultModel(task!)); - const providerHelper = getProviderHelper(provider, task); if (providerHelper.clientSideRoutingOnly && !maybeModel) { throw new Error(`Provider ${provider} requires a model ID to be passed directly.`); @@ -62,7 +63,7 @@ export async function makeRequestOptions( }); // Use the sync version with the resolved model - return makeRequestOptionsFromResolvedModel(resolvedModel, args, options); + return makeRequestOptionsFromResolvedModel(resolvedModel, providerHelper, args, options); } /** @@ -71,6 +72,7 @@ export async function makeRequestOptions( */ export function makeRequestOptionsFromResolvedModel( resolvedModel: string, + providerHelper: ReturnType, args: RequestArgs & { data?: Blob | ArrayBuffer; stream?: boolean; @@ -85,7 +87,6 @@ export function makeRequestOptionsFromResolvedModel( const provider = maybeProvider ?? "hf-inference"; const { includeCredentials, task, signal, billTo } = options ?? {}; - const providerHelper = getProviderHelper(provider, task); const authMethod = (() => { if (providerHelper.clientSideRoutingOnly) { // Closed-source providers require an accessToken (cannot be routed). diff --git a/packages/inference/src/providers/hf-inference.ts b/packages/inference/src/providers/hf-inference.ts index 0d01a2e794..f014c7815c 100644 --- a/packages/inference/src/providers/hf-inference.ts +++ b/packages/inference/src/providers/hf-inference.ts @@ -385,13 +385,13 @@ export class HFInferenceQuestionAnsweringTask extends HFInferenceTask implements typeof elem.end === "number" && typeof elem.score === "number" && typeof elem.start === "number" - ) + ) : typeof response === "object" && - !!response && - typeof response.answer === "string" && - typeof response.end === "number" && - typeof response.score === "number" && - typeof response.start === "number" + !!response && + typeof response.answer === "string" && + typeof response.end === "number" && + typeof response.score === "number" && + typeof response.start === "number" ) { return Array.isArray(response) ? response[0] : response; } diff --git a/packages/inference/src/snippets/getInferenceSnippets.ts b/packages/inference/src/snippets/getInferenceSnippets.ts index e8a5ce4025..ce2b75862c 100644 --- a/packages/inference/src/snippets/getInferenceSnippets.ts +++ b/packages/inference/src/snippets/getInferenceSnippets.ts @@ -11,6 +11,7 @@ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingf import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions"; import type { InferenceProvider, InferenceTask, RequestArgs } from "../types"; import { templates } from "./templates.exported"; +import { getProviderHelper } from "../lib/getProviderHelper"; const PYTHON_CLIENTS = ["huggingface_hub", "fal_client", "requests", "openai"] as const; const JS_CLIENTS = ["fetch", "huggingface.js", "openai"] as const; @@ -130,10 +131,18 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar inputPreparationFn = prepareConversationalInput; task = "conversational"; } + let providerHelper: ReturnType; + try { + providerHelper = getProviderHelper(provider, task); + } catch (e) { + console.error(`Failed to get provider helper for ${provider} (${task})`, e); + return []; + } /// Prepare inputs + make request const inputs = inputPreparationFn ? inputPreparationFn(model, opts) : { inputs: getModelInputSnippet(model) }; const request = makeRequestOptionsFromResolvedModel( providerModelId ?? model.id, + providerHelper, { accessToken: accessToken, provider: provider, diff --git a/packages/inference/src/tasks/audio/audioToAudio.ts b/packages/inference/src/tasks/audio/audioToAudio.ts index f93fe60951..fa055d3234 100644 --- a/packages/inference/src/tasks/audio/audioToAudio.ts +++ b/packages/inference/src/tasks/audio/audioToAudio.ts @@ -38,7 +38,7 @@ export interface AudioToAudioOutput { export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-to-audio"); const payload = preparePayload(args); - const { data: res } = await innerRequest(payload, { + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "audio-to-audio", }); diff --git a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts index 0f0c99339a..c71af07427 100644 --- a/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts +++ b/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts @@ -20,7 +20,7 @@ export async function automaticSpeechRecognition( ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "automatic-speech-recognition"); const payload = await buildPayload(args); - const { data: res } = await innerRequest(payload, { + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "automatic-speech-recognition", }); diff --git a/packages/inference/src/tasks/audio/textToSpeech.ts b/packages/inference/src/tasks/audio/textToSpeech.ts index 6c33cb882f..11f01f436f 100644 --- a/packages/inference/src/tasks/audio/textToSpeech.ts +++ b/packages/inference/src/tasks/audio/textToSpeech.ts @@ -14,7 +14,7 @@ interface OutputUrlTextToSpeechGeneration { export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise { const provider = args.provider ?? "hf-inference"; const providerHelper = getProviderHelper(provider, "text-to-speech"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "text-to-speech", }); diff --git a/packages/inference/src/tasks/custom/request.ts b/packages/inference/src/tasks/custom/request.ts index 15828acca6..62f45f28b3 100644 --- a/packages/inference/src/tasks/custom/request.ts +++ b/packages/inference/src/tasks/custom/request.ts @@ -1,3 +1,4 @@ +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { InferenceTask, Options, RequestArgs } from "../../types"; import { innerRequest } from "../../utils/request"; @@ -15,6 +16,7 @@ export async function request( console.warn( "The request method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead." ); - const result = await innerRequest(args, options); + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task); + const result = await innerRequest(args, providerHelper, options); return result.data; } diff --git a/packages/inference/src/tasks/cv/imageClassification.ts b/packages/inference/src/tasks/cv/imageClassification.ts index a64b655ee6..e683ecb3e2 100644 --- a/packages/inference/src/tasks/cv/imageClassification.ts +++ b/packages/inference/src/tasks/cv/imageClassification.ts @@ -16,7 +16,7 @@ export async function imageClassification( ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-classification"); const payload = preparePayload(args); - const { data: res } = await innerRequest(payload, { + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "image-classification", }); diff --git a/packages/inference/src/tasks/cv/imageSegmentation.ts b/packages/inference/src/tasks/cv/imageSegmentation.ts index d8da9ef1e3..9de0e9a2ef 100644 --- a/packages/inference/src/tasks/cv/imageSegmentation.ts +++ b/packages/inference/src/tasks/cv/imageSegmentation.ts @@ -16,7 +16,7 @@ export async function imageSegmentation( ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-segmentation"); const payload = preparePayload(args); - const { data: res } = await innerRequest(payload, { + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "image-segmentation", }); diff --git a/packages/inference/src/tasks/cv/imageToImage.ts b/packages/inference/src/tasks/cv/imageToImage.ts index d97b57a1dd..49d8ca2be5 100644 --- a/packages/inference/src/tasks/cv/imageToImage.ts +++ b/packages/inference/src/tasks/cv/imageToImage.ts @@ -27,7 +27,7 @@ export async function imageToImage(args: ImageToImageArgs, options?: Options): P ), }; } - const { data: res } = await innerRequest(reqArgs, { + const { data: res } = await innerRequest(reqArgs, providerHelper, { ...options, task: "image-to-image", }); diff --git a/packages/inference/src/tasks/cv/imageToText.ts b/packages/inference/src/tasks/cv/imageToText.ts index 5132f1bb19..cdee706fa4 100644 --- a/packages/inference/src/tasks/cv/imageToText.ts +++ b/packages/inference/src/tasks/cv/imageToText.ts @@ -12,7 +12,7 @@ export type ImageToTextArgs = BaseArgs & (ImageToTextInput | LegacyImageInput); export async function imageToText(args: ImageToTextArgs, options?: Options): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "image-to-text"); const payload = preparePayload(args); - const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, { + const { data: res } = await innerRequest<[ImageToTextOutput]>(payload, providerHelper, { ...options, task: "image-to-text", }); diff --git a/packages/inference/src/tasks/cv/objectDetection.ts b/packages/inference/src/tasks/cv/objectDetection.ts index 2ec6e5bca4..d103feeb96 100644 --- a/packages/inference/src/tasks/cv/objectDetection.ts +++ b/packages/inference/src/tasks/cv/objectDetection.ts @@ -13,7 +13,7 @@ export type ObjectDetectionArgs = BaseArgs & (ObjectDetectionInput | LegacyImage export async function objectDetection(args: ObjectDetectionArgs, options?: Options): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "object-detection"); const payload = preparePayload(args); - const { data: res } = await innerRequest(payload, { + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "object-detection", }); diff --git a/packages/inference/src/tasks/cv/textToImage.ts b/packages/inference/src/tasks/cv/textToImage.ts index a03edc6afd..490a8e10b6 100644 --- a/packages/inference/src/tasks/cv/textToImage.ts +++ b/packages/inference/src/tasks/cv/textToImage.ts @@ -25,11 +25,11 @@ export async function textToImage( export async function textToImage(args: TextToImageArgs, options?: TextToImageOptions): Promise { const provider = args.provider ?? "hf-inference"; const providerHelper = getProviderHelper(provider, "text-to-image"); - const { data: res } = await innerRequest>(args, { + const { data: res } = await innerRequest>(args, providerHelper, { ...options, task: "text-to-image", }); - const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-image" }); + const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-image" }); return providerHelper.getResponse(res, url, info.headers as Record, options?.outputType); } diff --git a/packages/inference/src/tasks/cv/textToVideo.ts b/packages/inference/src/tasks/cv/textToVideo.ts index d032554aed..9143e147a8 100644 --- a/packages/inference/src/tasks/cv/textToVideo.ts +++ b/packages/inference/src/tasks/cv/textToVideo.ts @@ -14,10 +14,14 @@ export type TextToVideoOutput = Blob; export async function textToVideo(args: TextToVideoArgs, options?: Options): Promise { const provider = args.provider ?? "hf-inference"; const providerHelper = getProviderHelper(provider, "text-to-video"); - const { data: response } = await innerRequest(args, { - ...options, - task: "text-to-video", - }); - const { url, info } = await makeRequestOptions(args, { ...options, task: "text-to-video" }); + const { data: response } = await innerRequest( + args, + providerHelper, + { + ...options, + task: "text-to-video", + } + ); + const { url, info } = await makeRequestOptions(args, providerHelper, { ...options, task: "text-to-video" }); return providerHelper.getResponse(response, url, info.headers as Record); } diff --git a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts index d317f3ff77..8f09e74aac 100644 --- a/packages/inference/src/tasks/cv/zeroShotImageClassification.ts +++ b/packages/inference/src/tasks/cv/zeroShotImageClassification.ts @@ -46,7 +46,7 @@ export async function zeroShotImageClassification( ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-image-classification"); const payload = await preparePayload(args); - const { data: res } = await innerRequest(payload, { + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "zero-shot-image-classification", }); diff --git a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts index 6702319dcd..033708dce0 100644 --- a/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/documentQuestionAnswering.ts @@ -30,6 +30,7 @@ export async function documentQuestionAnswering( } as RequestArgs; const { data: res } = await innerRequest( reqArgs, + providerHelper, { ...options, task: "document-question-answering", diff --git a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts index 8c795e336f..98a6616206 100644 --- a/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts +++ b/packages/inference/src/tasks/multimodal/visualQuestionAnswering.ts @@ -29,7 +29,7 @@ export async function visualQuestionAnswering( }, } as RequestArgs; - const { data: res } = await innerRequest(reqArgs, { + const { data: res } = await innerRequest(reqArgs, providerHelper, { ...options, task: "visual-question-answering", }); diff --git a/packages/inference/src/tasks/nlp/chatCompletion.ts b/packages/inference/src/tasks/nlp/chatCompletion.ts index 3467f8b179..4ad9be5f1c 100644 --- a/packages/inference/src/tasks/nlp/chatCompletion.ts +++ b/packages/inference/src/tasks/nlp/chatCompletion.ts @@ -11,7 +11,7 @@ export async function chatCompletion( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational"); - const { data: response } = await innerRequest(args, { + const { data: response } = await innerRequest(args, providerHelper, { ...options, task: "conversational", }); diff --git a/packages/inference/src/tasks/nlp/featureExtraction.ts b/packages/inference/src/tasks/nlp/featureExtraction.ts index 03980830dc..6dfe0bd382 100644 --- a/packages/inference/src/tasks/nlp/featureExtraction.ts +++ b/packages/inference/src/tasks/nlp/featureExtraction.ts @@ -18,7 +18,7 @@ export async function featureExtraction( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "feature-extraction"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "feature-extraction", }); diff --git a/packages/inference/src/tasks/nlp/fillMask.ts b/packages/inference/src/tasks/nlp/fillMask.ts index 061a3aee7b..663db87d99 100644 --- a/packages/inference/src/tasks/nlp/fillMask.ts +++ b/packages/inference/src/tasks/nlp/fillMask.ts @@ -10,7 +10,7 @@ export type FillMaskArgs = BaseArgs & FillMaskInput; */ export async function fillMask(args: FillMaskArgs, options?: Options): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "fill-mask"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "fill-mask", }); diff --git a/packages/inference/src/tasks/nlp/questionAnswering.ts b/packages/inference/src/tasks/nlp/questionAnswering.ts index 64ea84c0a6..6559c80c1a 100644 --- a/packages/inference/src/tasks/nlp/questionAnswering.ts +++ b/packages/inference/src/tasks/nlp/questionAnswering.ts @@ -13,9 +13,13 @@ export async function questionAnswering( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "question-answering"); - const { data: res } = await innerRequest(args, { - ...options, - task: "question-answering", - }); + const { data: res } = await innerRequest( + args, + providerHelper, + { + ...options, + task: "question-answering", + } + ); return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts index 920b322283..faa751f73e 100644 --- a/packages/inference/src/tasks/nlp/sentenceSimilarity.ts +++ b/packages/inference/src/tasks/nlp/sentenceSimilarity.ts @@ -13,7 +13,7 @@ export async function sentenceSimilarity( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "sentence-similarity"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "sentence-similarity", }); diff --git a/packages/inference/src/tasks/nlp/summarization.ts b/packages/inference/src/tasks/nlp/summarization.ts index 78156ee473..4b4205bf4b 100644 --- a/packages/inference/src/tasks/nlp/summarization.ts +++ b/packages/inference/src/tasks/nlp/summarization.ts @@ -10,7 +10,7 @@ export type SummarizationArgs = BaseArgs & SummarizationInput; */ export async function summarization(args: SummarizationArgs, options?: Options): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "summarization"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "summarization", }); diff --git a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts index f2749573dc..3939115862 100644 --- a/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts +++ b/packages/inference/src/tasks/nlp/tableQuestionAnswering.ts @@ -13,9 +13,13 @@ export async function tableQuestionAnswering( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "table-question-answering"); - const { data: res } = await innerRequest(args, { - ...options, - task: "table-question-answering", - }); + const { data: res } = await innerRequest( + args, + providerHelper, + { + ...options, + task: "table-question-answering", + } + ); return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/textClassification.ts b/packages/inference/src/tasks/nlp/textClassification.ts index 336e156836..7631d82286 100644 --- a/packages/inference/src/tasks/nlp/textClassification.ts +++ b/packages/inference/src/tasks/nlp/textClassification.ts @@ -13,7 +13,7 @@ export async function textClassification( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-classification"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "text-classification", }); diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts index 48560cd3c8..c2266c7968 100644 --- a/packages/inference/src/tasks/nlp/textGeneration.ts +++ b/packages/inference/src/tasks/nlp/textGeneration.ts @@ -17,7 +17,7 @@ export async function textGeneration( const providerHelper = getProviderHelper(provider, "text-generation"); const { data: response } = await innerRequest< HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[] - >(args, { + >(args, providerHelper, { ...options, task: "text-generation", }); diff --git a/packages/inference/src/tasks/nlp/tokenClassification.ts b/packages/inference/src/tasks/nlp/tokenClassification.ts index 0a21569b38..0c52b9e6a6 100644 --- a/packages/inference/src/tasks/nlp/tokenClassification.ts +++ b/packages/inference/src/tasks/nlp/tokenClassification.ts @@ -13,9 +13,13 @@ export async function tokenClassification( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "token-classification"); - const { data: res } = await innerRequest(args, { - ...options, - task: "token-classification", - }); + const { data: res } = await innerRequest( + args, + providerHelper, + { + ...options, + task: "token-classification", + } + ); return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/nlp/translation.ts b/packages/inference/src/tasks/nlp/translation.ts index 2d3576c40a..1f463e3e67 100644 --- a/packages/inference/src/tasks/nlp/translation.ts +++ b/packages/inference/src/tasks/nlp/translation.ts @@ -9,7 +9,7 @@ export type TranslationArgs = BaseArgs & TranslationInput; */ export async function translation(args: TranslationArgs, options?: Options): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "translation"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "translation", }); diff --git a/packages/inference/src/tasks/nlp/zeroShotClassification.ts b/packages/inference/src/tasks/nlp/zeroShotClassification.ts index 4877dfaea8..30d6d0c156 100644 --- a/packages/inference/src/tasks/nlp/zeroShotClassification.ts +++ b/packages/inference/src/tasks/nlp/zeroShotClassification.ts @@ -13,9 +13,13 @@ export async function zeroShotClassification( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "zero-shot-classification"); - const { data: res } = await innerRequest(args, { - ...options, - task: "zero-shot-classification", - }); + const { data: res } = await innerRequest( + args, + providerHelper, + { + ...options, + task: "zero-shot-classification", + } + ); return providerHelper.getResponse(res); } diff --git a/packages/inference/src/tasks/tabular/tabularClassification.ts b/packages/inference/src/tasks/tabular/tabularClassification.ts index bdc9ab6008..9174c17718 100644 --- a/packages/inference/src/tasks/tabular/tabularClassification.ts +++ b/packages/inference/src/tasks/tabular/tabularClassification.ts @@ -26,7 +26,7 @@ export async function tabularClassification( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-classification"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "tabular-classification", }); diff --git a/packages/inference/src/tasks/tabular/tabularRegression.ts b/packages/inference/src/tasks/tabular/tabularRegression.ts index 5e691e5c5c..2c2408ffde 100644 --- a/packages/inference/src/tasks/tabular/tabularRegression.ts +++ b/packages/inference/src/tasks/tabular/tabularRegression.ts @@ -26,7 +26,7 @@ export async function tabularRegression( options?: Options ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "tabular-regression"); - const { data: res } = await innerRequest(args, { + const { data: res } = await innerRequest(args, providerHelper, { ...options, task: "tabular-regression", }); diff --git a/packages/inference/src/utils/request.ts b/packages/inference/src/utils/request.ts index 8753f019f7..3af8a128a0 100644 --- a/packages/inference/src/utils/request.ts +++ b/packages/inference/src/utils/request.ts @@ -1,3 +1,4 @@ +import type { getProviderHelper } from "../lib/getProviderHelper"; import { makeRequestOptions } from "../lib/makeRequestOptions"; import type { InferenceTask, Options, RequestArgs } from "../types"; import type { EventSourceMessage } from "../vendor/fetch-event-source/parse"; @@ -16,6 +17,7 @@ export interface ResponseWrapper { */ export async function innerRequest( args: RequestArgs, + providerHelper: ReturnType, options?: Options & { /** In most cases (unless we pass a endpointUrl) we know the task */ task?: InferenceTask; @@ -23,13 +25,13 @@ export async function innerRequest( chatCompletion?: boolean; } ): Promise> { - const { url, info } = await makeRequestOptions(args, options); + const { url, info } = await makeRequestOptions(args, providerHelper, options); const response = await (options?.fetch ?? fetch)(url, info); const requestContext: ResponseWrapper["requestContext"] = { url, info }; if (options?.retry_on_error !== false && response.status === 503) { - return innerRequest(args, options); + return innerRequest(args, providerHelper, options); } if (!response.ok) { @@ -65,6 +67,7 @@ export async function innerRequest( */ export async function* innerStreamingRequest( args: RequestArgs, + providerHelper: ReturnType, options?: Options & { /** In most cases (unless we pass a endpointUrl) we know the task */ task?: InferenceTask; @@ -72,11 +75,11 @@ export async function* innerStreamingRequest( chatCompletion?: boolean; } ): AsyncGenerator { - const { url, info } = await makeRequestOptions({ ...args, stream: true }, options); + const { url, info } = await makeRequestOptions({ ...args, stream: true }, providerHelper, options); const response = await (options?.fetch ?? fetch)(url, info); if (options?.retry_on_error !== false && response.status === 503) { - return yield* innerStreamingRequest(args, options); + return yield* innerStreamingRequest(args, providerHelper, options); } if (!response.ok) { if (response.headers.get("Content-Type")?.startsWith("application/json")) { From 179b9de1aff7022090156158d6828237b02f7378 Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Fri, 11 Apr 2025 15:16:38 +0200 Subject: [PATCH 2/3] update remaining tasks --- packages/inference/src/tasks/audio/audioClassification.ts | 2 +- packages/inference/src/tasks/custom/streamingRequest.ts | 5 ++++- packages/inference/src/tasks/nlp/chatCompletionStream.ts | 5 +++-- packages/inference/src/tasks/nlp/textGeneration.ts | 3 +-- packages/inference/src/tasks/nlp/textGenerationStream.ts | 4 +++- 5 files changed, 12 insertions(+), 7 deletions(-) diff --git a/packages/inference/src/tasks/audio/audioClassification.ts b/packages/inference/src/tasks/audio/audioClassification.ts index 25ad508158..f1ff1c20c7 100644 --- a/packages/inference/src/tasks/audio/audioClassification.ts +++ b/packages/inference/src/tasks/audio/audioClassification.ts @@ -17,7 +17,7 @@ export async function audioClassification( ): Promise { const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "audio-classification"); const payload = preparePayload(args); - const { data: res } = await innerRequest(payload, { + const { data: res } = await innerRequest(payload, providerHelper, { ...options, task: "audio-classification", }); diff --git a/packages/inference/src/tasks/custom/streamingRequest.ts b/packages/inference/src/tasks/custom/streamingRequest.ts index dd11858c4e..45ae99f323 100644 --- a/packages/inference/src/tasks/custom/streamingRequest.ts +++ b/packages/inference/src/tasks/custom/streamingRequest.ts @@ -1,5 +1,7 @@ +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { InferenceTask, Options, RequestArgs } from "../../types"; import { innerStreamingRequest } from "../../utils/request"; + /** * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator * @deprecated Use specific task functions instead. This function will be removed in a future version. @@ -14,5 +16,6 @@ export async function* streamingRequest( console.warn( "The streamingRequest method is deprecated and will be removed in a future version of huggingface.js. Use specific task functions instead." ); - yield* innerStreamingRequest(args, options); + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", options?.task); + yield* innerStreamingRequest(args, providerHelper, options); } diff --git a/packages/inference/src/tasks/nlp/chatCompletionStream.ts b/packages/inference/src/tasks/nlp/chatCompletionStream.ts index 64392210b5..2b46e729d9 100644 --- a/packages/inference/src/tasks/nlp/chatCompletionStream.ts +++ b/packages/inference/src/tasks/nlp/chatCompletionStream.ts @@ -1,4 +1,5 @@ import type { ChatCompletionInput, ChatCompletionStreamOutput } from "@huggingface/tasks"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { innerStreamingRequest } from "../../utils/request"; @@ -9,8 +10,8 @@ export async function* chatCompletionStream( args: BaseArgs & ChatCompletionInput, options?: Options ): AsyncGenerator { - yield* innerStreamingRequest(args, { + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational"); + yield* innerStreamingRequest(args, providerHelper, { ...options, - task: "conversational", }); } diff --git a/packages/inference/src/tasks/nlp/textGeneration.ts b/packages/inference/src/tasks/nlp/textGeneration.ts index c2266c7968..5d84e543cb 100644 --- a/packages/inference/src/tasks/nlp/textGeneration.ts +++ b/packages/inference/src/tasks/nlp/textGeneration.ts @@ -13,8 +13,7 @@ export async function textGeneration( args: BaseArgs & TextGenerationInput, options?: Options ): Promise { - const provider = args.provider ?? "hf-inference"; - const providerHelper = getProviderHelper(provider, "text-generation"); + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation"); const { data: response } = await innerRequest< HyperbolicTextCompletionOutput | TextGenerationOutput | TextGenerationOutput[] >(args, providerHelper, { diff --git a/packages/inference/src/tasks/nlp/textGenerationStream.ts b/packages/inference/src/tasks/nlp/textGenerationStream.ts index de6d84e72c..59706eaa5c 100644 --- a/packages/inference/src/tasks/nlp/textGenerationStream.ts +++ b/packages/inference/src/tasks/nlp/textGenerationStream.ts @@ -1,4 +1,5 @@ import type { TextGenerationInput } from "@huggingface/tasks"; +import { getProviderHelper } from "../../lib/getProviderHelper"; import type { BaseArgs, Options } from "../../types"; import { innerStreamingRequest } from "../../utils/request"; @@ -89,7 +90,8 @@ export async function* textGenerationStream( args: BaseArgs & TextGenerationInput, options?: Options ): AsyncGenerator { - yield* innerStreamingRequest(args, { + const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "text-generation"); + yield* innerStreamingRequest(args, providerHelper, { ...options, task: "text-generation", }); From 3161239038d7b5610fee1abd03fb290e73930ecf Mon Sep 17 00:00:00 2001 From: Celina Hanouti Date: Fri, 11 Apr 2025 15:27:05 +0200 Subject: [PATCH 3/3] fix --- packages/inference/src/tasks/nlp/chatCompletionStream.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/inference/src/tasks/nlp/chatCompletionStream.ts b/packages/inference/src/tasks/nlp/chatCompletionStream.ts index 2b46e729d9..cf88044112 100644 --- a/packages/inference/src/tasks/nlp/chatCompletionStream.ts +++ b/packages/inference/src/tasks/nlp/chatCompletionStream.ts @@ -13,5 +13,6 @@ export async function* chatCompletionStream( const providerHelper = getProviderHelper(args.provider ?? "hf-inference", "conversational"); yield* innerStreamingRequest(args, providerHelper, { ...options, + task: "conversational", }); }