diff --git a/.github/workflows/inference-check-snippets.yml b/.github/workflows/inference-check-snippets.yml new file mode 100644 index 0000000000..37e9843006 --- /dev/null +++ b/.github/workflows/inference-check-snippets.yml @@ -0,0 +1,42 @@ +name: Inference check snippets +on: + pull_request: + paths: + - "packages/tasks/src/snippets/**" + - ".github/workflows/inference-check-snippets.yml" + +jobs: + check-snippets: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - run: corepack enable + + - uses: actions/setup-node@v3 + with: + node-version: "20" + cache: "pnpm" + cache-dependency-path: "**/pnpm-lock.yaml" + - run: | + cd packages/tasks + pnpm install + + # TODO: Find a way to run on all pipeline tags + # TODO: print snippet only if it has changed since the last commit on main (?) + # TODO: (even better: automated message on the PR with diff) + - name: Print text-to-image snippets + run: | + cd packages/tasks + pnpm run check-snippets --pipeline-tag="text-to-image" + + - name: Print simple text-generation snippets + run: | + cd packages/tasks + pnpm run check-snippets --pipeline-tag="text-generation" + + - name: Print conversational text-generation snippets + run: | + cd packages/tasks + pnpm run check-snippets --pipeline-tag="text-generation" --tags="conversational" diff --git a/packages/tasks/package.json b/packages/tasks/package.json index 318365a7f4..4ef9bc8da4 100644 --- a/packages/tasks/package.json +++ b/packages/tasks/package.json @@ -32,7 +32,8 @@ "check": "tsc", "inference-codegen": "tsx scripts/inference-codegen.ts && prettier --write src/tasks/*/inference.ts", "inference-tgi-import": "tsx scripts/inference-tgi-import.ts && prettier --write src/tasks/text-generation/spec/*.json && prettier --write src/tasks/chat-completion/spec/*.json", - "inference-tei-import": "tsx scripts/inference-tei-import.ts && prettier --write src/tasks/feature-extraction/spec/*.json" + "inference-tei-import": "tsx scripts/inference-tei-import.ts && 
prettier --write src/tasks/feature-extraction/spec/*.json", + "check-snippets": "tsx scripts/check-snippets.ts" }, "type": "module", "files": [ diff --git a/packages/tasks/scripts/check-snippets.ts b/packages/tasks/scripts/check-snippets.ts new file mode 100644 index 0000000000..dc72282ac1 --- /dev/null +++ b/packages/tasks/scripts/check-snippets.ts @@ -0,0 +1,55 @@ +/* + * Generates inference snippets as they would be shown on the Hub for Curl, JS and Python. + * Snippets will only be printed to the terminal to make it easier to debug when making changes to the snippets. + * + * Usage: + * pnpm run check-snippets --pipeline-tag="text-generation" --tags="conversational" + * pnpm run check-snippets --pipeline-tag="image-text-to-text" --tags="conversational" + * pnpm run check-snippets --pipeline-tag="text-to-image" + * + * This script is meant only for debug purposes. + */ +import { python, curl, js } from "../src/snippets/index"; +import type { InferenceSnippet, ModelDataMinimal } from "../src/snippets/types"; +import type { PipelineType } from "../src/pipelines"; + +// Parse command-line arguments +const args = process.argv.slice(2).reduce( + (acc, arg) => { + const [key, value] = arg.split("="); + acc[key.replace("--", "")] = value; + return acc; + }, + {} as { [key: string]: string } +); + +const accessToken = "hf_**********"; +const pipelineTag = (args["pipeline-tag"] || "text-generation") as PipelineType; +const tags = (args["tags"] ?? "").split(","); + +const modelMinimal: ModelDataMinimal = { + id: "llama-6-1720B-Instruct", + pipeline_tag: pipelineTag, + tags: tags, + inference: "****", +}; + +const printSnippets = (snippets: InferenceSnippet | InferenceSnippet[], language: string) => { + const snippetArray = Array.isArray(snippets) ? 
snippets : [snippets]; + snippetArray.forEach((snippet) => { + console.log(`\n\x1b[33m${language} ${snippet.client}\x1b[0m`); + console.log(`\n\`\`\`${language}\n${snippet.content}\n\`\`\`\n`); + }); +}; + +const generateAndPrintSnippets = ( + generator: (model: ModelDataMinimal, token: string) => InferenceSnippet | InferenceSnippet[], + language: string +) => { + const snippets = generator(modelMinimal, accessToken); + printSnippets(snippets, language); +}; + +generateAndPrintSnippets(curl.getCurlInferenceSnippet, "curl"); +generateAndPrintSnippets(python.getPythonInferenceSnippet, "python"); +generateAndPrintSnippets(js.getJsInferenceSnippet, "js"); diff --git a/packages/tasks/src/snippets/python.ts b/packages/tasks/src/snippets/python.ts index d2b0f25850..203f03f05d 100644 --- a/packages/tasks/src/snippets/python.ts +++ b/packages/tasks/src/snippets/python.ts @@ -4,6 +4,11 @@ import { stringifyGenerationConfig, stringifyMessages } from "./common.js"; import { getModelInputSnippet } from "./inputs.js"; import type { InferenceSnippet, ModelDataMinimal } from "./types.js"; +const snippetImportInferenceClient = (model: ModelDataMinimal, accessToken: string): string => + `from huggingface_hub import InferenceClient + +client = InferenceClient("${model.id}", token="${accessToken || "{API_TOKEN}"}")`; + export const snippetConversational = ( model: ModelDataMinimal, accessToken: string, @@ -168,10 +173,14 @@ export const snippetFile = (model: ModelDataMinimal): InferenceSnippet => ({ output = query(${getModelInputSnippet(model)})`, }); -export const snippetTextToImage = (model: ModelDataMinimal): InferenceSnippet => ({ - content: `def query(payload): +export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string): InferenceSnippet[] => { + return [ + { + client: "requests", + content: `def query(payload): response = requests.post(API_URL, headers=headers, json=payload) return response.content + +image_bytes = query({ "inputs": 
${getModelInputSnippet(model)}, }) @@ -179,7 +188,16 @@ image_bytes = query({ import io from PIL import Image image = Image.open(io.BytesIO(image_bytes))`, -}); + }, + { + client: "huggingface_hub", + content: `${snippetImportInferenceClient(model, accessToken)} + +# output is a PIL.Image object +image = client.text_to_image(${getModelInputSnippet(model)})`, + }, + ]; +}; export const snippetTabular = (model: ModelDataMinimal): InferenceSnippet => ({ content: `def query(payload): @@ -284,6 +302,9 @@ export function getPythonInferenceSnippet( if (model.tags.includes("conversational")) { // Conversational model detected, so we display a code snippet that features the Messages API return snippetConversational(model, accessToken, opts); + } else if (model.pipeline_tag === "text-to-image") { + // TODO: factorize this logic + return snippetTextToImage(model, accessToken); + } else { let snippets = model.pipeline_tag && model.pipeline_tag in pythonSnippets