diff --git a/packages/tasks/src/local-apps.spec.ts b/packages/tasks/src/local-apps.spec.ts
new file mode 100644
index 0000000000..23806f668d
--- /dev/null
+++ b/packages/tasks/src/local-apps.spec.ts
@@ -0,0 +1,123 @@
+import { describe, expect, it } from "vitest";
+import { LOCAL_APPS } from "./local-apps.js";
+import type { ModelData } from "./model-data.js";
+
+describe("local-apps", () => {
+	it("llama.cpp conversational", async () => {
+		const { snippet: snippetFunc } = LOCAL_APPS["llama.cpp"];
+		const model: ModelData = {
+			id: "bartowski/Llama-3.2-3B-Instruct-GGUF",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetFunc(model);
+
+		expect(snippet[0].content).toEqual(`# Load and run the model:
+llama-cli \\
+  --hf-repo "bartowski/Llama-3.2-3B-Instruct-GGUF" \\
+  --hf-file {{GGUF_FILE}} \\
+  -p "You are a helpful assistant" \\
+  --conversation`);
+	});
+
+	it("llama.cpp non-conversational", async () => {
+		const { snippet: snippetFunc } = LOCAL_APPS["llama.cpp"];
+		const model: ModelData = {
+			id: "mlabonne/gemma-2b-GGUF",
+			tags: [],
+			inference: "",
+		};
+		const snippet = snippetFunc(model);
+
+		expect(snippet[0].content).toEqual(`# Load and run the model:
+llama-cli \\
+  --hf-repo "mlabonne/gemma-2b-GGUF" \\
+  --hf-file {{GGUF_FILE}} \\
+  -p "Once upon a time,"`);
+	});
+
+	it("vLLM conversational llm", async () => {
+		const { snippet: snippetFunc } = LOCAL_APPS["vllm"];
+		const model: ModelData = {
+			id: "meta-llama/Llama-3.2-3B-Instruct",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetFunc(model);
+
+		expect((snippet[0].content as string[]).join("\n")).toEqual(`# Load and run the model:
+vllm serve "meta-llama/Llama-3.2-3B-Instruct"
+# Call the server using curl:
+curl -X POST "http://localhost:8000/v1/chat/completions" \\
+	-H "Content-Type: application/json" \\
+	--data '{
+		"model": "meta-llama/Llama-3.2-3B-Instruct",
+		"messages": [
+			{
+				"role": "user",
+				"content": "What is the capital of France?"
+			}
+		]
+	}'`);
+	});
+
+	it("vLLM non-conversational llm", async () => {
+		const { snippet: snippetFunc } = LOCAL_APPS["vllm"];
+		const model: ModelData = {
+			id: "meta-llama/Llama-3.2-3B",
+			tags: [""],
+			inference: "",
+		};
+		const snippet = snippetFunc(model);
+
+		expect((snippet[0].content as string[]).join("\n")).toEqual(`# Load and run the model:
+vllm serve "meta-llama/Llama-3.2-3B"
+# Call the server using curl:
+curl -X POST "http://localhost:8000/v1/completions" \\
+	-H "Content-Type: application/json" \\
+	--data '{
+		"model": "meta-llama/Llama-3.2-3B",
+		"prompt": "Once upon a time,",
+		"max_tokens": 512,
+		"temperature": 0.5
+	}'`);
+	});
+
+	it("vLLM conversational vlm", async () => {
+		const { snippet: snippetFunc } = LOCAL_APPS["vllm"];
+		const model: ModelData = {
+			id: "meta-llama/Llama-3.2-11B-Vision-Instruct",
+			pipeline_tag: "image-text-to-text",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = snippetFunc(model);
+
+		expect((snippet[0].content as string[]).join("\n")).toEqual(`# Load and run the model:
+vllm serve "meta-llama/Llama-3.2-11B-Vision-Instruct"
+# Call the server using curl:
+curl -X POST "http://localhost:8000/v1/chat/completions" \\
+	-H "Content-Type: application/json" \\
+	--data '{
+		"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
+		"messages": [
+			{
+				"role": "user",
+				"content": [
+					{
+						"type": "text",
+						"text": "Describe this image in one sentence."
+					},
+					{
+						"type": "image_url",
+						"image_url": {
+							"url": "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg"
+						}
+					}
+				]
+			}
+		]
+	}'`);
+	});
+});
diff --git a/packages/tasks/src/local-apps.ts b/packages/tasks/src/local-apps.ts
index edc7e64fd8..2249183a4c 100644
--- a/packages/tasks/src/local-apps.ts
+++ b/packages/tasks/src/local-apps.ts
@@ -1,6 +1,9 @@
 import { parseGGUFQuantLabel } from "./gguf.js";
 import type { ModelData } from "./model-data.js";
 import type { PipelineType } from "./pipelines.js";
+import { stringifyMessages } from "./snippets/common.js";
+import { getModelInputSnippet } from "./snippets/inputs.js";
+import type { ChatCompletionInputMessage } from "./tasks/index.js";
 
 export interface LocalAppSnippet {
 	/**
@@ -92,15 +95,20 @@ function isMlxModel(model: ModelData) {
 }
 
 const snippetLlamacpp = (model: ModelData, filepath?: string): LocalAppSnippet[] => {
-	const command = (binary: string) =>
-		[
+	const command = (binary: string) => {
+		const snippet = [
 			"# Load and run the model:",
 			`${binary} \\`,
 			`  --hf-repo "${model.id}" \\`,
 			`  --hf-file ${filepath ?? "{{GGUF_FILE}}"} \\`,
-			'  -p "You are a helpful assistant" \\',
-			"  --conversation",
-		].join("\n");
+			`  -p "${model.tags.includes("conversational") ? "You are a helpful assistant" : "Once upon a time,"}"`,
+		];
+		if (model.tags.includes("conversational")) {
+			snippet[snippet.length - 1] += " \\";
+			snippet.push("  --conversation");
+		}
+		return snippet.join("\n");
+	};
 	return [
 		{
 			title: "Install from brew",
@@ -178,22 +186,33 @@ const snippetLocalAI = (model: ModelData, filepath?: string): LocalAppSnippet[]
 };
 
 const snippetVllm = (model: ModelData): LocalAppSnippet[] => {
-	const runCommand = [
-		"# Call the server using curl:",
-		`curl -X POST "http://localhost:8000/v1/chat/completions" \\`,
-		`	-H "Content-Type: application/json" \\`,
-		`	--data '{`,
-		`		"model": "${model.id}",`,
-		`		"messages": [`,
-		`			{"role": "user", "content": "Hello!"}`,
-		`		]`,
-		`	}'`,
-	];
+	const messages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
+	const runCommandInstruct = `# Call the server using curl:
+curl -X POST "http://localhost:8000/v1/chat/completions" \\
+	-H "Content-Type: application/json" \\
+	--data '{
+		"model": "${model.id}",
+		"messages": ${stringifyMessages(messages, {
+			indent: "\t\t",
+			attributeKeyQuotes: true,
+			customContentEscaper: (str) => str.replace(/'/g, "'\\''"),
+		})}
+	}'`;
+	const runCommandNonInstruct = `# Call the server using curl:
+curl -X POST "http://localhost:8000/v1/completions" \\
+	-H "Content-Type: application/json" \\
+	--data '{
+		"model": "${model.id}",
+		"prompt": "Once upon a time,",
+		"max_tokens": 512,
+		"temperature": 0.5
+	}'`;
+	const runCommand = model.tags.includes("conversational") ? runCommandInstruct : runCommandNonInstruct;
 	return [
 		{
 			title: "Install from pip",
 			setup: ["# Install vLLM from pip:", "pip install vllm"].join("\n"),
-			content: [`# Load and run the model:\nvllm serve "${model.id}"`, runCommand.join("\n")],
+			content: [`# Load and run the model:\nvllm serve "${model.id}"`, runCommand],
 		},
 		{
 			title: "Use Docker images",
@@ -210,7 +229,7 @@ const snippetVllm = (model: ModelData): LocalAppSnippet[] => {
 			].join("\n"),
 			content: [
 				`# Load and run the model:\ndocker exec -it my_vllm_container bash -c "vllm serve ${model.id}"`,
-				runCommand.join("\n"),
+				runCommand,
 			],
 		},
 	];
diff --git a/packages/tasks/src/model-libraries-snippets.spec.ts b/packages/tasks/src/model-libraries-snippets.spec.ts
new file mode 100644
index 0000000000..fa87d82423
--- /dev/null
+++ b/packages/tasks/src/model-libraries-snippets.spec.ts
@@ -0,0 +1,54 @@
+import { describe, expect, it } from "vitest";
+import type { ModelData } from "./model-data.js";
+import { llama_cpp_python } from "./model-libraries-snippets.js";
+
+describe("model-libraries-snippets", () => {
+	it("llama_cpp_python conversational", async () => {
+		const model: ModelData = {
+			id: "bartowski/Llama-3.2-3B-Instruct-GGUF",
+			pipeline_tag: "text-generation",
+			tags: ["conversational"],
+			inference: "",
+		};
+		const snippet = llama_cpp_python(model);
+
+		expect(snippet.join("\n")).toEqual(`from llama_cpp import Llama
+
+llm = Llama.from_pretrained(
+	repo_id="bartowski/Llama-3.2-3B-Instruct-GGUF",
+	filename="{{GGUF_FILE}}",
+)
+
+llm.create_chat_completion(
+	messages = [
+		{
+			"role": "user",
+			"content": "What is the capital of France?"
+		}
+	]
+)`);
+	});
+
+	it("llama_cpp_python non-conversational", async () => {
+		const model: ModelData = {
+			id: "mlabonne/gemma-2b-GGUF",
+			tags: [""],
+			inference: "",
+		};
+		const snippet = llama_cpp_python(model);
+
+		expect(snippet.join("\n")).toEqual(`from llama_cpp import Llama
+
+llm = Llama.from_pretrained(
+	repo_id="mlabonne/gemma-2b-GGUF",
+	filename="{{GGUF_FILE}}",
+)
+
+output = llm(
+	"Once upon a time,",
+	max_tokens=512,
+	echo=True
+)
+print(output)`);
+	});
+});
diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts
index 523d1a2457..bb6ac12c3c 100644
--- a/packages/tasks/src/model-libraries-snippets.ts
+++ b/packages/tasks/src/model-libraries-snippets.ts
@@ -1,6 +1,9 @@
 import type { ModelData } from "./model-data.js";
 import type { WidgetExampleTextInput, WidgetExampleSentenceSimilarityInput } from "./widget-example.js";
 import { LIBRARY_TASK_MAPPING } from "./library-to-tasks.js";
+import { getModelInputSnippet } from "./snippets/inputs.js";
+import type { ChatCompletionInputMessage } from "./tasks/index.js";
+import { stringifyMessages } from "./snippets/common.js";
 
 const TAG_CUSTOM_CODE = "custom_code";
 
@@ -418,23 +421,33 @@ model = keras_hub.models.CausalLM.from_preset("hf://${model.id}", dtype="bfloat1
 `,
 ];
 
-export const llama_cpp_python = (model: ModelData): string[] => [
-	`from llama_cpp import Llama
+export const llama_cpp_python = (model: ModelData): string[] => {
+	const snippets = [
+		`from llama_cpp import Llama
 
 llm = Llama.from_pretrained(
 	repo_id="${model.id}",
 	filename="{{GGUF_FILE}}",
 )
+`,
+	];
 
-llm.create_chat_completion(
-	messages = [
-		{
-			"role": "user",
-			"content": "What is the capital of France?"
-		}
-	]
-)`,
-];
+	if (model.tags.includes("conversational")) {
+		const messages = getModelInputSnippet(model) as ChatCompletionInputMessage[];
+		snippets.push(`llm.create_chat_completion(
+	messages = ${stringifyMessages(messages, { attributeKeyQuotes: true, indent: "\t" })}
+)`);
+	} else {
+		snippets.push(`output = llm(
+	"Once upon a time,",
+	max_tokens=512,
+	echo=True
+)
+print(output)`);
+	}
+
+	return snippets;
+};
 
 export const tf_keras = (model: ModelData): string[] => [
 	`# Note: 'keras<3.x' or 'tf_keras' must be installed (legacy)