29 changes: 7 additions & 22 deletions packages/inference/src/tasks/audio/audioClassification.ts
@@ -1,27 +1,11 @@
import type { AudioClassificationInput, AudioClassificationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
import type { LegacyAudioInput } from "./utils";
import { preparePayload } from "./utils";

export type AudioClassificationArgs = BaseArgs & {
/**
* Binary audio data
*/
data: Blob | ArrayBuffer;
};

export interface AudioClassificationOutputValue {
/**
* The label for the class (model specific)
*/
label: string;

/**
* A float that represents how likely it is that the audio file belongs to this class.
*/
score: number;
}

export type AudioClassificationReturn = AudioClassificationOutputValue[];
export type AudioClassificationArgs = BaseArgs & (AudioClassificationInput | LegacyAudioInput);

/**
* This task reads some audio input and outputs the likelihood of classes.
@@ -30,8 +14,9 @@ export type AudioClassificationReturn = AudioClassificationOutputValue[];
export async function audioClassification(
args: AudioClassificationArgs,
options?: Options
): Promise<AudioClassificationReturn> {
const res = await request<AudioClassificationReturn>(args, {
): Promise<AudioClassificationOutput> {
const payload = preparePayload(args);
const res = await request<AudioClassificationOutput>(payload, {
...options,
taskHint: "audio-classification",
});
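A minimal call sketch against the new signature (token, URL, and model id are placeholders; per the union type above, both the new `inputs` key and the legacy `data` key should be accepted):

```ts
import { audioClassification } from "@huggingface/inference";

// Any audio source works; this URL is illustrative.
const audioBlob = await (await fetch("https://example.com/sample.wav")).blob();

// New-style input: the binary payload travels under `inputs`.
const labels = await audioClassification({
  accessToken: "hf_...", // your token
  model: "superb/hubert-large-superb-er", // any audio-classification model
  inputs: audioBlob,
});

// Legacy input: `data` (Blob | ArrayBuffer) still works and is
// normalized by preparePayload() before the request is sent.
const legacyLabels = await audioClassification({
  accessToken: "hf_...",
  model: "superb/hubert-large-superb-er",
  data: await audioBlob.arrayBuffer(),
});
```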
66 changes: 43 additions & 23 deletions packages/inference/src/tasks/audio/audioToAudio.ts
@@ -1,15 +1,19 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
import type { LegacyAudioInput } from "./utils";
import { preparePayload } from "./utils";

export type AudioToAudioArgs = BaseArgs & {
/**
* Binary audio data
*/
data: Blob | ArrayBuffer;
};
export type AudioToAudioArgs =
| (BaseArgs & {
/**
* Binary audio data
*/
inputs: Blob;
})
| LegacyAudioInput;

export interface AudioToAudioOutputValue {
export interface AudioToAudioOutputElem {
/**
* The label for the audio output (model specific)
*/
@@ -18,32 +22,48 @@ export interface AudioToAudioOutputValue {
/**
* Base64 encoded audio output.
*/
blob: string;
audio: Blob;
}

/**
* Content-type for blob, e.g. audio/flac
*/
export interface AudioToAudioOutput {
blob: string;
"content-type": string;
label: string;
}

export type AudioToAudioReturn = AudioToAudioOutputValue[];

/**
* This task reads some audio input and outputs one or multiple audio files.
* Example model: speechbrain/sepformer-wham does audio source separation.
*/
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioReturn> {
const res = await request<AudioToAudioReturn>(args, {
export async function audioToAudio(args: AudioToAudioArgs, options?: Options): Promise<AudioToAudioOutput[]> {
const payload = preparePayload(args);
const res = await request<AudioToAudioOutput>(payload, {
...options,
taskHint: "audio-to-audio",
});
const isValidOutput =
Array.isArray(res) &&
res.every(
(x) => typeof x.label === "string" && typeof x.blob === "string" && typeof x["content-type"] === "string"
);
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");

return validateOutput(res);
}

function validateOutput(output: unknown): AudioToAudioOutput[] {
if (!Array.isArray(output)) {
throw new InferenceOutputError("Expected Array");
}
if (
!output.every((elem): elem is AudioToAudioOutput => {
return (
typeof elem === "object" &&
elem &&
"label" in elem &&
typeof elem.label === "string" &&
"content-type" in elem &&
typeof elem["content-type"] === "string" &&
"blob" in elem &&
typeof elem.blob === "string"
);
})
) {
throw new InferenceOutputError("Expected Array<{label: string, blob: string, content-type: string}>");
}
return res;
return output;
}
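Note that the validated wire shape still carries base64 audio under `blob` plus a `content-type`, so decoding to a binary Blob is left to the caller. A small sketch, assuming exactly the element shape checked above:

```ts
// Decode one element's base64 `blob` string into a typed binary Blob.
// atob is a global in browsers and in Node >= 16.
function decodeAudioElement(elem: { label: string; blob: string; "content-type": string }): {
  label: string;
  audio: Blob;
} {
  const bytes = Uint8Array.from(atob(elem.blob), (c) => c.charCodeAt(0));
  return { label: elem.label, audio: new Blob([bytes], { type: elem["content-type"] }) };
}
```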
58 changes: 35 additions & 23 deletions packages/inference/src/tasks/audio/automaticSpeechRecognition.ts
@@ -1,22 +1,13 @@
import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options, RequestArgs } from "../../types";
import { base64FromBytes } from "../../utils/base64FromBytes";
import { request } from "../custom/request";
import type { LegacyAudioInput } from "./utils";
import { preparePayload } from "./utils";
import { omit } from "../../utils/omit";

export type AutomaticSpeechRecognitionArgs = BaseArgs & {
/**
* Binary audio data
*/
data: Blob | ArrayBuffer;
};

export interface AutomaticSpeechRecognitionOutput {
/**
* The text that was recognized from the audio
*/
text: string;
}

export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
/**
* This task reads some audio input and outputs the words spoken within the audio file.
* Recommended model (english language): facebook/wav2vec2-large-960h-lv60-self
@@ -25,15 +16,8 @@ export async function automaticSpeechRecognition(
args: AutomaticSpeechRecognitionArgs,
options?: Options
): Promise<AutomaticSpeechRecognitionOutput> {
if (args.provider === "fal-ai") {
const contentType = args.data instanceof Blob ? args.data.type : "audio/mpeg";
const base64audio = base64FromBytes(
new Uint8Array(args.data instanceof ArrayBuffer ? args.data : await args.data.arrayBuffer())
);
(args as RequestArgs & { audio_url: string }).audio_url = `data:${contentType};base64,${base64audio}`;
delete (args as RequestArgs & { data: unknown }).data;
}
const res = await request<AutomaticSpeechRecognitionOutput>(args, {
const payload = await buildPayload(args);
const res = await request<AutomaticSpeechRecognitionOutput>(payload, {
...options,
taskHint: "automatic-speech-recognition",
});
@@ -43,3 +27,31 @@ export async function automaticSpeechRecognition(
}
return res;
}

const FAL_AI_SUPPORTED_BLOB_TYPES = ["audio/mpeg", "audio/mp4", "audio/wav", "audio/x-wav"];

async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
if (args.provider === "fal-ai") {
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
const contentType = blob?.type;
if (!contentType) {
throw new Error(
`Unable to determine the input's content-type. Make sure you are passing a Blob when using provider fal-ai.`
);
}
if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
throw new Error(
`Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
", "
)}`
);
}
const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
return {
...("data" in args ? omit(args, "data") : omit(args, "inputs")),
audio_url: `data:${contentType};base64,${base64audio}`,
};
} else {
return preparePayload(args);
}
}
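A call sketch for the fal-ai branch (token and model id are placeholders): the input must be a Blob with one of the supported MIME types, since `buildPayload` inlines it as a base64 `data:` URL.

```ts
import { automaticSpeechRecognition } from "@huggingface/inference";

// Must be a Blob with a supported type (here audio/mpeg) when provider is "fal-ai";
// a fetch()ed Blob takes its type from the response's Content-Type header.
const speech = await (await fetch("https://example.com/speech.mp3")).blob();

const { text } = await automaticSpeechRecognition({
  accessToken: "hf_...",
  provider: "fal-ai", // omit to use the default Inference API route
  model: "openai/whisper-large-v3", // illustrative model id
  inputs: speech,
});
console.log(text);
```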
22 changes: 8 additions & 14 deletions packages/inference/src/tasks/audio/textToSpeech.ts
@@ -1,27 +1,25 @@
import type { TextToSpeechInput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

export type TextToSpeechArgs = BaseArgs & {
/**
* The text to generate an audio from
*/
inputs: string;
};
type TextToSpeechArgs = BaseArgs & TextToSpeechInput;

export type TextToSpeechOutput = Blob;
interface OutputUrlTextToSpeechGeneration {
output: string | string[];
}
/**
* This task synthesizes audio of a voice pronouncing a given text.
* Recommended model: espnet/kan-bayashi_ljspeech_vits
*/
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<TextToSpeechOutput> {
const res = await request<TextToSpeechOutput | OutputUrlTextToSpeechGeneration>(args, {
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
const res = await request<Blob | OutputUrlTextToSpeechGeneration>(args, {
...options,
taskHint: "text-to-speech",
});
if (res instanceof Blob) {
return res;
}
if (res && typeof res === "object") {
if ("output" in res) {
if (typeof res.output === "string") {
@@ -35,9 +33,5 @@ export async function textToSpeech(args: TextToSpeechArgs, options?: Options): P
}
}
}
const isValidOutput = res && res instanceof Blob;
if (!isValidOutput) {
throw new InferenceOutputError("Expected Blob");
}
return res;
throw new InferenceOutputError("Expected Blob or object with output");
}
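Either way the caller now just receives a Blob; a usage sketch (model id taken from the doc comment above, persistence assumes Node >= 18 with Blob and fetch globals):

```ts
import { writeFile } from "node:fs/promises";
import { textToSpeech } from "@huggingface/inference";

const speech = await textToSpeech({
  accessToken: "hf_...", // your token
  model: "espnet/kan-bayashi_ljspeech_vits",
  inputs: "Hello world!",
});

// `speech` is a Blob whether the provider returned raw bytes or an output URL.
await writeFile("hello.flac", Buffer.from(await speech.arrayBuffer()));
```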
18 changes: 18 additions & 0 deletions packages/inference/src/tasks/audio/utils.ts
@@ -0,0 +1,18 @@
import type { BaseArgs, RequestArgs } from "../../types";
import { omit } from "../../utils/omit";

/**
* @deprecated
*/
export interface LegacyAudioInput {
data: Blob | ArrayBuffer;
}

export function preparePayload(args: BaseArgs & ({ inputs: Blob } | LegacyAudioInput)): RequestArgs {
return "data" in args
? args
: {
...omit(args, "inputs"),
data: args.inputs,
};
}
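Behavior sketch with hypothetical values: the helper only renames the new-style `inputs` key to the legacy `data` key that `request` expects, and passes legacy args through untouched.

```ts
const blob = new Blob([new Uint8Array([0x52, 0x49, 0x46, 0x46])], { type: "audio/wav" });

preparePayload({ model: "some/model", inputs: blob });
// -> { model: "some/model", data: blob }  (inputs renamed to data)

preparePayload({ model: "some/model", data: blob });
// -> returned as-is: legacy args already carry `data`
```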
25 changes: 5 additions & 20 deletions packages/inference/src/tasks/cv/imageClassification.ts
@@ -1,26 +1,10 @@
import type { ImageClassificationInput, ImageClassificationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
import { preparePayload, type LegacyImageInput } from "./utils";

export type ImageClassificationArgs = BaseArgs & {
/**
* Binary image data
*/
data: Blob | ArrayBuffer;
};

export interface ImageClassificationOutputValue {
/**
* The label for the class (model specific)
*/
label: string;
/**
* A float that represents how likely it is that the image file belongs to this class.
*/
score: number;
}

export type ImageClassificationOutput = ImageClassificationOutputValue[];
export type ImageClassificationArgs = BaseArgs & (ImageClassificationInput | LegacyImageInput);

/**
* This task reads some image input and outputs the likelihood of classes.
@@ -30,7 +14,8 @@ export async function imageClassification(
args: ImageClassificationArgs,
options?: Options
): Promise<ImageClassificationOutput> {
const res = await request<ImageClassificationOutput>(args, {
const payload = preparePayload(args);
const res = await request<ImageClassificationOutput>(payload, {
...options,
taskHint: "image-classification",
});
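Same migration pattern as the audio tasks; a quick sketch of the resulting call (placeholders as before, and the legacy `data` key remains usable):

```ts
import { imageClassification } from "@huggingface/inference";

const image = await (await fetch("https://example.com/cat.png")).blob();

const classes = await imageClassification({
  accessToken: "hf_...",
  model: "google/vit-base-patch16-224", // illustrative model id
  inputs: image, // or legacy: data: image
});
// e.g. [{ label: "tabby, tabby cat", score: 0.94 }, ...]
```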
29 changes: 5 additions & 24 deletions packages/inference/src/tasks/cv/imageSegmentation.ts
@@ -1,30 +1,10 @@
import type { ImageSegmentationInput, ImageSegmentationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
import { preparePayload, type LegacyImageInput } from "./utils";

export type ImageSegmentationArgs = BaseArgs & {
/**
* Binary image data
*/
data: Blob | ArrayBuffer;
};

export interface ImageSegmentationOutputValue {
/**
* The label for the class (model specific) of a segment.
*/
label: string;
/**
* A str (base64 str of a single channel black-and-white img) representing the mask of a segment.
*/
mask: string;
/**
* A float that represents how likely it is that the detected object belongs to the given class.
*/
score: number;
}

export type ImageSegmentationOutput = ImageSegmentationOutputValue[];
export type ImageSegmentationArgs = BaseArgs & (ImageSegmentationInput | LegacyImageInput);

/**
* This task reads some image input and outputs the likelihood of classes & segmentation masks of detected objects.
@@ -34,7 +14,8 @@ export async function imageSegmentation(
args: ImageSegmentationArgs,
options?: Options
): Promise<ImageSegmentationOutput> {
const res = await request<ImageSegmentationOutput>(args, {
const payload = preparePayload(args);
const res = await request<ImageSegmentationOutput>(payload, {
...options,
taskHint: "image-segmentation",
});
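Since the segmentation masks remain base64-encoded single-channel PNG strings (per the removed `ImageSegmentationOutputValue` doc above), a small browser-side decoding sketch may help downstream display; the helper name is hypothetical:

```ts
// Turn one segment's base64 `mask` string into an object URL for an <img> tag.
// Browser-only: URL.createObjectURL is not available in plain Node.
function maskToObjectURL(mask: string): string {
  const bytes = Uint8Array.from(atob(mask), (c) => c.charCodeAt(0));
  return URL.createObjectURL(new Blob([bytes], { type: "image/png" }));
}
```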