Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hermes-2-Pro function calling example with JSON schema #390

Merged
merged 2 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/json-schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
"url": "^0.11.3"
},
"dependencies": {
"@mlc-ai/web-llm": "^0.2.35"
"@mlc-ai/web-llm": "file:../.."
}
}
254 changes: 205 additions & 49 deletions examples/json-schema/src/json_schema.ts
Original file line number Diff line number Diff line change
@@ -1,63 +1,219 @@
import * as webllm from "@mlc-ai/web-llm";
import { Type, Static } from "@sinclair/typebox";

/**
 * Writes `text` into the DOM element whose id is `id`.
 *
 * @param id - Id of the element to update.
 * @param text - Value assigned to the element's `innerText`.
 * @throws Error when `document.getElementById(id)` finds no element.
 */
function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw new Error("Cannot find label " + id);
  }
  label.innerText = text;
}

// There are several ways to provide a JSON schema for constrained output.
// 1. Define the schema directly as a JSON string.
// NOTE(review): a `schema1` with the same content is also declared inside
// simpleStructuredTextExample; this module-level copy looks like the
// pre-change side of the diff -- confirm whether it is still needed.
const schema1 = `{
"properties": {
"size": {"title": "Size", "type": "integer"},
"is_accepted": {"title": "Is Accepted", "type": "boolean"},
"num": {"title": "Num", "type": "number"}
},
"required": ["size", "is_accepted", "num"],
"title": "Schema", "type": "object"
}`;
/**
 * Asks the model for a small JSON object constrained by a JSON schema.
 *
 * Creates an engine for "Llama-2-7b-chat-hf-q4f16_1" (load progress is written
 * to the "init-label" element), sends one non-streaming chat request whose
 * `response_format` carries the schema, then logs the reply, the generated
 * message and the runtime statistics to the console.
 */
async function simpleStructuredTextExample() {
  // There are several options of providing such a schema
  // 1. You can directly define a schema in string
  //    (shown for illustration only; the request below uses schema2).
  const schema1 = `{
    "properties": {
      "size": {"title": "Size", "type": "integer"},
      "is_accepted": {"title": "Is Accepted", "type": "boolean"},
      "num": {"title": "Num", "type": "number"}
    },
    "required": ["size", "is_accepted", "num"],
    "title": "Schema", "type": "object"
  }`;

  // 2. You can use 3rdparty libraries like typebox to create a schema
  const T = Type.Object({
    size: Type.Integer(),
    is_accepted: Type.Boolean(),
    num: Type.Number(),
  });
  // Static TypeScript type derived from the schema value above.
  type T = Static<typeof T>;
  const schema2 = JSON.stringify(T);
  console.log(schema2);
  // {"type":"object","properties":{"size":{"type":"integer"},"is_accepted":{"type":"boolean"},
  // "num":{"type":"number"}},"required":["size","is_accepted","num"]}

  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };
  const engine: webllm.EngineInterface = await webllm.CreateEngine(
    "Llama-2-7b-chat-hf-q4f16_1",
    { initProgressCallback: initProgressCallback }
  );

  const request: webllm.ChatCompletionRequest = {
    stream: false, // works with streaming, logprobs, top_logprobs as well
    messages: [
      {
        role: "user",
        content:
          "Generate a json containing three fields: an integer field named size, a " +
          "boolean field named is_accepted, and a float field named num.",
      },
    ],
    max_gen_len: 128,
    response_format: {
      type: "json_object",
      schema: schema2,
    } as webllm.ResponseFormat,
  };

  const reply0 = await engine.chatCompletion(request);
  console.log(reply0);
  console.log("Output:\n" + (await engine.getMessage()));
  console.log(await engine.runtimeStatsText());
}

// The json schema and prompt is taken from
// https://github.com/sgl-project/sglang/tree/main?tab=readme-ov-file#json-decoding
/**
 * Asks the model to describe a Harry Potter character as JSON matching a
 * nested schema (string enums plus a nested `wand` object).
 *
 * Creates an engine for "Llama-2-7b-chat-hf-q4f16_1" (load progress is written
 * to the "init-label" element), sends one non-streaming request constrained by
 * the schema, then logs the reply, the generated message and the runtime
 * statistics to the console.
 */
async function harryPotterExample() {
  const T = Type.Object({
    name: Type.String(),
    house: Type.Enum({
      Gryffindor: "Gryffindor",
      Hufflepuff: "Hufflepuff",
      Ravenclaw: "Ravenclaw",
      Slytherin: "Slytherin",
    }),
    blood_status: Type.Enum({
      "Pure-blood": "Pure-blood",
      "Half-blood": "Half-blood",
      "Muggle-born": "Muggle-born",
    }),
    occupation: Type.Enum({
      Student: "Student",
      Professor: "Professor",
      "Ministry of Magic": "Ministry of Magic",
      Other: "Other",
    }),
    wand: Type.Object({
      wood: Type.String(),
      core: Type.String(),
      length: Type.Number(),
    }),
    alive: Type.Boolean(),
    patronus: Type.String(),
  });

  const schema = JSON.stringify(T);
  console.log(schema);

  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };

  const engine: webllm.EngineInterface = await webllm.CreateEngine(
    "Llama-2-7b-chat-hf-q4f16_1",
    { initProgressCallback: initProgressCallback }
  );

  const request: webllm.ChatCompletionRequest = {
    stream: false,
    messages: [
      {
        role: "user",
        content:
          // The trailing space keeps the two sentences from running together
          // ("...JSON format.Name is...") once the literals are concatenated.
          "Hermione Granger is a character in Harry Potter. Please fill in the following information about this character in JSON format. " +
          "Name is a string of character name. House is one of Gryffindor, Hufflepuff, Ravenclaw, Slytherin. Blood status is one of Pure-blood, Half-blood, Muggle-born. Occupation is one of Student, Professor, Ministry of Magic, Other. Wand is an object with wood, core, and length. Alive is a boolean. Patronus is a string.",
      },
    ],
    max_gen_len: 128,
    response_format: {
      type: "json_object",
      schema: schema,
    } as webllm.ResponseFormat,
  };

  const reply = await engine.chatCompletion(request);
  console.log(reply);
  console.log("Output:\n" + (await engine.getMessage()));
  console.log(await engine.runtimeStatsText());
}

/**
 * Demonstrates function calling by constraining the model output to a
 * `{ tool_calls: [{ name, arguments }] }` JSON object.
 *
 * Creates an engine for "Hermes-2-Pro-Mistral-7B-q4f16_1" (load progress is
 * written to the "init-label" element). The system prompt lists the available
 * tools and the expected output schema; the same schema is also passed as
 * `response_format`. Logs the first choice's message content and the runtime
 * statistics to the console.
 */
async function functionCallingExample() {
  const T = Type.Object({
    tool_calls: Type.Array(
      Type.Object({
        arguments: Type.Any(),
        name: Type.String(),
      })
    ),
  });
  // `schema` is already JSON text; it must not be stringified a second time.
  const schema = JSON.stringify(T);
  console.log(schema);

  const tools: Array<webllm.ChatCompletionTool> = [
    {
      type: "function",
      function: {
        name: "get_current_weather",
        description: "Get the current weather in a given location",
        parameters: {
          type: "object",
          properties: {
            location: {
              type: "string",
              description: "The city and state, e.g. San Francisco, CA",
            },
            unit: { type: "string", enum: ["celsius", "fahrenheit"] },
          },
          required: ["location"],
        },
      },
    },
  ];

  // Built line by line so the prompt text does not depend on source indentation.
  // The schema is interpolated as-is: wrapping it in JSON.stringify again would
  // show the model a quoted, backslash-escaped string instead of the schema.
  const systemPrompt = [
    "You are a function calling AI model. You are provided with function signatures within <tools></tools> XML tags. " +
      "You may call one or more functions to assist with the user query. " +
      "Don't make assumptions about what values to plug into functions. " +
      `Here are the available tools: <tools> ${JSON.stringify(tools)} </tools>. ` +
      "Do not stop calling functions until the task has been accomplished or you've reached max iteration of 10.",
    "Calling multiple functions at once can overload the system and increase cost so call one function at a time please.",
    "If you plan to continue with analysis, always call another function.",
    `Return a valid json object (using double quotes) in the following schema: ${schema}.`,
  ].join("\n");

  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };

  const selectedModel = "Hermes-2-Pro-Mistral-7B-q4f16_1";
  const engine: webllm.EngineInterface = await webllm.CreateEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
    }
  );

  const request: webllm.ChatCompletionRequest = {
    stream: false,
    messages: [
      {
        role: "system",
        content: systemPrompt,
      },
      {
        role: "user",
        content:
          "What is the current weather in celsius in Pittsburgh and Tokyo?",
      },
    ],
    response_format: {
      type: "json_object",
      schema: schema,
    } as webllm.ResponseFormat,
  };

  const reply = await engine.chat.completions.create(request);
  console.log(reply.choices[0].message.content);

  console.log(await engine.runtimeStatsText());
}

/**
 * Entry point: runs one of the JSON-schema examples.
 *
 * Only the function-calling example is enabled; uncomment one of the other
 * calls to run that example instead.
 */
async function main() {
  // await simpleStructuredTextExample();
  // await harryPotterExample();
  await functionCallingExample();
}

main();
2 changes: 1 addition & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions src/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,14 @@ export const prebuiltAppConfig: AppConfig = {
"low_resource_required": false,
"required_features": ["shader-f16"],
},
{
"model_url": "https://huggingface.co/mlc-ai/Hermes-2-Pro-Mistral-7B-q4f16_1-MLC/resolve/main/",
"model_id": "Hermes-2-Pro-Mistral-7B-q4f16_1",
"model_lib_url": modelLibURLPrefix + modelVersion + "/Hermes-2-Pro-Mistral-7B-q4f16_1-sw4k_cs1k-webgpu.wasm",
"vram_required_MB": 4033.28,
"low_resource_required": false,
"required_features": ["shader-f16"],
},
// Gemma-2B
{
"model_url": "https://huggingface.co/mlc-ai/gemma-2b-it-q4f16_1-MLC/resolve/main/",
Expand Down