diff --git a/packages/tasks/package.json b/packages/tasks/package.json
index 318365a7f4..77318d6961 100644
--- a/packages/tasks/package.json
+++ b/packages/tasks/package.json
@@ -30,6 +30,7 @@
 		"watch": "npm-run-all --parallel watch:export watch:types",
 		"prepare": "pnpm run build",
 		"check": "tsc",
+		"test": "vitest run",
 		"inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write src/tasks/*/inference.ts",
 		"inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write src/tasks/text-generation/spec/*.json && prettier --write src/tasks/chat-completion/spec/*.json",
 		"inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json"
diff --git a/packages/tasks/src/snippets/common.ts b/packages/tasks/src/snippets/common.ts
index 24b0613406..0f82db8150 100644
--- a/packages/tasks/src/snippets/common.ts
+++ b/packages/tasks/src/snippets/common.ts
@@ -1,63 +1,39 @@
 import type { ChatCompletionInputMessage, GenerationParameters } from "../tasks";
 
-export interface StringifyMessagesOptions {
-	sep: string;
-	start: string;
-	end: string;
-	attributeKeyQuotes?: boolean;
-	customContentEscaper?: (str: string) => string;
-}
-
-export function stringifyMessages(messages: ChatCompletionInputMessage[], opts: StringifyMessagesOptions): string {
-	const keyRole = opts.attributeKeyQuotes ? `"role"` : "role";
-	const keyContent = opts.attributeKeyQuotes ? `"content"` : "content";
-
-	const messagesStringified = messages.map(({ role, content }) => {
-		if (typeof content === "string") {
-			content = JSON.stringify(content).slice(1, -1);
-			if (opts.customContentEscaper) {
-				content = opts.customContentEscaper(content);
-			}
-			return `{ ${keyRole}: "${role}", ${keyContent}: "${content}" }`;
-		} else {
-			2;
-			content = content.map(({ image_url, text, type }) => ({
-				type,
-				image_url,
-				...(text ? { text: JSON.stringify(text).slice(1, -1) } : undefined),
-			}));
-			content = JSON.stringify(content).slice(1, -1);
-			if (opts.customContentEscaper) {
-				content = opts.customContentEscaper(content);
-			}
-			return `{ ${keyRole}: "${role}", ${keyContent}: [${content}] }`;
-		}
-	});
-
-	return opts.start + messagesStringified.join(opts.sep) + opts.end;
+export function stringifyMessages(
+	messages: ChatCompletionInputMessage[],
+	opts?: {
+		indent?: string;
+		attributeKeyQuotes?: boolean;
+		customContentEscaper?: (str: string) => string;
+	}
+): string {
+	let messagesStr = JSON.stringify(messages, null, "\t");
+	if (opts?.indent) {
+		messagesStr = messagesStr.replaceAll("\n", `\n${opts.indent}`);
+	}
+	if (!opts?.attributeKeyQuotes) {
+		messagesStr = messagesStr.replace(/"([^"]+)":/g, "$1:");
+	}
+	if (opts?.customContentEscaper) {
+		messagesStr = opts.customContentEscaper(messagesStr);
+	}
+	return messagesStr;
 }
 
 type PartialGenerationParameters = Partial<Pick<GenerationParameters, "temperature" | "max_tokens" | "top_p">>;
 
-export interface StringifyGenerationConfigOptions {
-	sep: string;
-	start: string;
-	end: string;
-	attributeValueConnector: string;
-	attributeKeyQuotes?: boolean;
-}
-
 export function stringifyGenerationConfig(
 	config: PartialGenerationParameters,
-	opts: StringifyGenerationConfigOptions
+	opts: {
+		indent: string;
+		attributeValueConnector: string;
+		attributeKeyQuotes?: boolean;
+	}
 ): string {
 	const quote = opts.attributeKeyQuotes ? `"` : "";
 
-	return (
-		opts.start +
-		Object.entries(config)
-			.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${val}`)
-			.join(opts.sep) +
-		opts.end
-	);
+	return Object.entries(config)
+		.map(([key, val]) => `${quote}${key}${quote}${opts.attributeValueConnector}${val}`)
+		.join(`,${opts.indent}`);
 }
diff --git a/packages/tasks/src/snippets/curl.spec.ts b/packages/tasks/src/snippets/curl.spec.ts
new file mode 100644
index 0000000000..0b854dbb8e
--- /dev/null
+++ b/packages/tasks/src/snippets/curl.spec.ts
@@ -0,0 +1,68 @@
+import type { ModelDataMinimal } from "./types";
+import { describe, expect, it } from "vitest";
+import { snippetTextGeneration } from "./curl";
+
+describe("inference API snippets", () => {
+	it("conversational llm", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetTextGeneration(model, "api_token");
+
+		expect(snippet.content)
+			.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.1-8B-Instruct/v1/chat/completions' \\
+-H "Authorization: Bearer api_token" \\
+-H 'Content-Type: application/json' \\
+--data '{
+    "model": "meta-llama/Llama-3.1-8B-Instruct",
+    "messages": [
+		{
+			"role": "user",
+			"content": "What is the capital of France?"
+		}
+	],
+    "max_tokens": 500,
+    "stream": true
+}'`);
+	});
+
+	it("conversational vlm", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+			pipeline_tag: "image-text-to-text",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetTextGeneration(model, "api_token");
+
+		expect(snippet.content)
+			.toEqual(`curl 'https://api-inference.huggingface.co/models/meta-llama/Llama-3.2-11B-Vision-Instruct/v1/chat/completions' \\
+-H "Authorization: Bearer api_token" \\
+-H 'Content-Type: application/json' \\
+--data '{
+    "model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+    "messages": [
+		{
+			"role": "user",
+			"content": [
+				{
+					"type": "text",
+					"text": "Describe this image in one sentence."
+				},
+				{
+					"type": "image_url",
+					"image_url": {
+						"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+					}
+				}
+			]
+		}
+	],
+    "max_tokens": 500,
+    "stream": true
+}'`);
+	});
+});
diff --git a/packages/tasks/src/snippets/curl.ts b/packages/tasks/src/snippets/curl.ts
index af4ada2678..28f8a59c7c 100644
--- a/packages/tasks/src/snippets/curl.ts
+++ b/packages/tasks/src/snippets/curl.ts
@@ -41,16 +41,12 @@
 --data '{
     "model": "${model.id}",
     "messages": ${stringifyMessages(messages, {
-		sep: ",\n\t\t",
-		start: `[\n\t\t`,
-		end: `\n\t]`,
+		indent: "\t",
 		attributeKeyQuotes: true,
 		customContentEscaper: (str) => str.replace(/'/g, "'\\''"),
 	})},
     ${stringifyGenerationConfig(config, {
-		sep: ",\n    ",
-		start: "",
-		end: "",
+		indent: "\n    ",
 		attributeKeyQuotes: true,
 		attributeValueConnector: ": ",
 	})},
diff --git a/packages/tasks/src/snippets/inputs.ts b/packages/tasks/src/snippets/inputs.ts
index 6a0404bdfb..70afde388c 100644
--- a/packages/tasks/src/snippets/inputs.ts
+++ b/packages/tasks/src/snippets/inputs.ts
@@ -128,6 +128,7 @@
 	"tabular-classification": inputsTabularPrediction,
 	"text-classification": inputsTextClassification,
 	"text-generation": inputsTextGeneration,
+	"image-text-to-text": inputsTextGeneration,
 	"text-to-image": inputsTextToImage,
 	"text-to-speech": inputsTextToSpeech,
 	"text-to-audio": inputsTextToAudio,
diff --git a/packages/tasks/src/snippets/js.spec.ts b/packages/tasks/src/snippets/js.spec.ts
new file mode 100644
index 0000000000..7780707211
--- /dev/null
+++ b/packages/tasks/src/snippets/js.spec.ts
@@ -0,0 +1,86 @@
+import type { InferenceSnippet, ModelDataMinimal } from "./types";
+import { describe, expect, it } from "vitest";
+import { snippetTextGeneration } from "./js";
+
+describe("inference API snippets", () => {
+	it("conversational llm", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];
+
+		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"
+
+const client = new HfInference("api_token")
+
+let out = "";
+
+const stream = client.chatCompletionStream({
+	model: "meta-llama/Llama-3.1-8B-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: "What is the capital of France?"
+		}
+	],
+	max_tokens: 500
+});
+
+for await (const chunk of stream) {
+	if (chunk.choices && chunk.choices.length > 0) {
+		const newContent = chunk.choices[0].delta.content;
+		out += newContent;
+		console.log(newContent);
+	}
+}`);
+	});
+
+	it("conversational vlm", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+			pipeline_tag: "image-text-to-text",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetTextGeneration(model, "api_token") as InferenceSnippet[];
+
+		expect(snippet[0].content).toEqual(`import { HfInference } from "@huggingface/inference"
+
+const client = new HfInference("api_token")
+
+let out = "";
+
+const stream = client.chatCompletionStream({
+	model: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+	messages: [
+		{
+			role: "user",
+			content: [
+				{
+					type: "text",
+					text: "Describe this image in one sentence."
+				},
+				{
+					type: "image_url",
+					image_url: {
+						url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+					}
+				}
+			]
+		}
+	],
+	max_tokens: 500
+});
+
+for await (const chunk of stream) {
+	if (chunk.choices && chunk.choices.length > 0) {
+		const newContent = chunk.choices[0].delta.content;
+		out += newContent;
+		console.log(newContent);
+	}
+}`);
+	});
+});
diff --git a/packages/tasks/src/snippets/js.ts b/packages/tasks/src/snippets/js.ts
index c261e08a1d..fdf25b46c5 100644
--- a/packages/tasks/src/snippets/js.ts
+++ b/packages/tasks/src/snippets/js.ts
@@ -42,7 +42,7 @@ export const snippetTextGeneration = (
 	const streaming = opts?.streaming ?? true;
 	const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
 	const messages = opts?.messages ?? exampleMessages;
-	const messagesStr = stringifyMessages(messages, { sep: ",\n\t\t", start: "[\n\t\t", end: "\n\t]" });
+	const messagesStr = stringifyMessages(messages, { indent: "\t" });
 
 	const config = {
 		...(opts?.temperature ? { temperature: opts.temperature } : undefined),
@@ -50,9 +50,7 @@
 		...(opts?.top_p ? { top_p: opts.top_p } : undefined),
 	};
 	const configStr = stringifyGenerationConfig(config, {
-		sep: ",\n\t",
-		start: "",
-		end: "",
+		indent: "\n\t",
 		attributeValueConnector: ": ",
 	});
 
diff --git a/packages/tasks/src/snippets/python.spec.ts b/packages/tasks/src/snippets/python.spec.ts
new file mode 100644
index 0000000000..3f1ee4979a
--- /dev/null
+++ b/packages/tasks/src/snippets/python.spec.ts
@@ -0,0 +1,78 @@
+import type { ModelDataMinimal } from "./types";
+import { describe, expect, it } from "vitest";
+import { snippetConversational } from "./python";
+
+describe("inference API snippets", () => {
+	it("conversational llm", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.1-8B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetConversational(model, "api_token");
+
+		expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient
+
+client = InferenceClient(api_key="api_token")
+
+messages = [
+	{
+		"role": "user",
+		"content": "What is the capital of France?"
+	}
+]
+
+stream = client.chat.completions.create(
+	model="meta-llama/Llama-3.1-8B-Instruct",
+	messages=messages,
+	max_tokens=500,
+	stream=True
+)
+
+for chunk in stream:
+    print(chunk.choices[0].delta.content, end="")`);
+	});
+
+	it("conversational vlm", async () => {
+		const model: ModelDataMinimal = {
+			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+			pipeline_tag: "image-text-to-text",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetConversational(model, "api_token");
+
+		expect(snippet[0].content).toEqual(`from huggingface_hub import InferenceClient
+
+client = InferenceClient(api_key="api_token")
+
+messages = [
+	{
+		"role": "user",
+		"content": [
+			{
+				"type": "text",
+				"text": "Describe this image in one sentence."
+			},
+			{
+				"type": "image_url",
+				"image_url": {
+					"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+				}
+			}
+		]
+	}
+]
+
+stream = client.chat.completions.create(
+	model="meta-llama/Llama-3.2-11B-Vision-Instruct",
+	messages=messages,
+	max_tokens=500,
+	stream=True
+)
+
+for chunk in stream:
+    print(chunk.choices[0].delta.content, end="")`);
+	});
+});
diff --git a/packages/tasks/src/snippets/python.ts b/packages/tasks/src/snippets/python.ts
index d2b0f25850..31ce47a10b 100644
--- a/packages/tasks/src/snippets/python.ts
+++ b/packages/tasks/src/snippets/python.ts
@@ -18,12 +18,7 @@ export const snippetConversational = (
 	const streaming = opts?.streaming ?? true;
 	const exampleMessages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
 	const messages = opts?.messages ?? exampleMessages;
-	const messagesStr = stringifyMessages(messages, {
-		sep: ",\n\t",
-		start: `[\n\t`,
-		end: `\n]`,
-		attributeKeyQuotes: true,
-	});
+	const messagesStr = stringifyMessages(messages, { attributeKeyQuotes: true });
 
 	const config = {
 		...(opts?.temperature ? { temperature: opts.temperature } : undefined),
@@ -31,9 +26,7 @@
 		...(opts?.top_p ? { top_p: opts.top_p } : undefined),
 	};
 	const configStr = stringifyGenerationConfig(config, {
-		sep: ",\n\t",
-		start: "",
-		end: "",
+		indent: "\n\t",
 		attributeValueConnector: "=",
 	});
 
@@ -55,7 +48,7 @@
 )
 
 for chunk in stream:
-    print(chunk.choices[0].delta.content)`,
+    print(chunk.choices[0].delta.content, end="")`,
 		},
 		{
 			client: "openai",
@@ -76,7 +69,7 @@
 )
 
 for chunk in stream:
-    print(chunk.choices[0].delta.content)`,
+    print(chunk.choices[0].delta.content, end="")`,
 		},
 	];
 } else {
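
Note on the refactored helpers (not part of the patch): instead of having every caller hand-assemble list syntax via `sep`/`start`/`end`, the new `stringifyMessages` delegates to `JSON.stringify(messages, null, "\t")` and post-processes the result: `indent` shifts continuation lines so the array nests inside the surrounding snippet, `attributeKeyQuotes` keeps or strips JSON key quotes (Python vs. JS style), and `customContentEscaper` handles shell embedding for curl. A minimal sketch of the new behavior, assuming the refactored `common.ts` above; the expected-output comments are derived by hand from that implementation:

```ts
import type { ChatCompletionInputMessage } from "../tasks";
import { stringifyMessages, stringifyGenerationConfig } from "./common";

const messages: ChatCompletionInputMessage[] = [{ role: "user", content: "What is the capital of France?" }];

// Python-style: keep JSON key quotes, no extra indentation.
console.log(stringifyMessages(messages, { attributeKeyQuotes: true }));
// [
// 	{
// 		"role": "user",
// 		"content": "What is the capital of France?"
// 	}
// ]

// JS-style: keys unquoted, continuation lines shifted one tab deeper
// so the array nests under `messages: ` in the generated snippet.
console.log(stringifyMessages(messages, { indent: "\t" }));

// Generation config entries are joined with `,${indent}`; here with the
// Python-style `=` connector used by snippetConversational.
console.log(
	stringifyGenerationConfig({ temperature: 0.5, max_tokens: 500 }, { indent: "\n\t", attributeValueConnector: "=" })
);
// temperature=0.5,
// 	max_tokens=500
```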