diff --git a/js/README.md b/js/README.md
index de8d378e9..dcd93fbd3 100644
--- a/js/README.md
+++ b/js/README.md
@@ -49,3 +49,6 @@ Click links for README of each examples.
 * [Facebook Segment-Anything](segment-anything) - demonstrates how to run [segment-anything](https://github.com/facebookresearch/segment-anything) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.
 
 * [Stable Diffusion Turbo](sd-turbo) - demonstrates how to run [Stable Diffusion Turbo](https://huggingface.co/stabilityai/sd-turbo) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.
+
+* [Phi-3-mini-4k-instruct](chat) - demonstrates how to run [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.
+
diff --git a/js/chat/README.md b/js/chat/README.md
new file mode 100644
index 000000000..0b9e2d7f3
--- /dev/null
+++ b/js/chat/README.md
@@ -0,0 +1,54 @@
+# Local Chat using Phi3, ONNX Runtime Web and WebGPU
+
+This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU.
+
+You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html).
+
+We keep this example simple and use the onnxruntime-web API directly, without a
+higher-level framework like [transformers.js](https://github.com/xenova/transformers.js).
+
+## Getting Started
+
+### Prerequisites
+
+Ensure that you have [Node.js](https://nodejs.org/) installed on your machine.
+
+### Installation
+
+Install the required dependencies:
+
+```sh
+npm install
+```
+
+### Building the project
+
+Build the project:
+
+```sh
+npm run build
+```
+
+The output can be found in the ***dist*** directory.
+
+### Building for development
+
+```sh
+npm run dev
+```
+
+This will build the project and start a dev server.
+Point your browser to http://localhost:8080/.
+
+### The Phi3 ONNX Model
+
+The model used in this example is hosted on [Hugging Face](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx-web). It is slightly different from the ONNX model for CUDA or CPU:
+1. The model output 'logits' is kept as float32 (even for float16 models) since JavaScript does not support float16.
+2. Our WebGPU implementation uses the custom MultiHeadAttention operator instead of Group Query Attention.
+3. Phi3 is larger than 2GB and we need to use external data files. To keep them cacheable in the browser,
+   both model.onnx and model.onnx.data are kept under 2GB.
+
+The model was created using the [ONNX genai model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models).
+
+If you would like to create the model yourself, you can use [Olive](https://github.com/microsoft/Olive/).
+An example of how to create the model for ONNX Runtime Web with Olive can be found [here](https://github.com/microsoft/Olive/tree/main/examples/phi3).
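Point 3 in the README above is why llm.js later in this diff fetches the graph file and the weight file separately and hands the second buffer to ONNX Runtime Web as external data. Below is a minimal sketch of that loading path; it assumes the generic file names from the README ("model.onnx" / "model.onnx.data") and uses plain fetch instead of the example's cached loader.

```js
import * as ort from 'onnxruntime-web/webgpu';

// Sketch only: fetch the graph and the external weight file separately,
// then point the session at the weight buffer via the externalData option
// (the same onnxruntime-web option llm.js uses).
async function loadPhi3(baseUrl) {
    const modelBytes = await fetch(`${baseUrl}/model.onnx`).then(r => r.arrayBuffer());
    const weightBytes = await fetch(`${baseUrl}/model.onnx.data`).then(r => r.arrayBuffer());

    return ort.InferenceSession.create(modelBytes, {
        executionProviders: ['webgpu'],
        // map the path recorded inside model.onnx to the bytes we just fetched
        externalData: [{ data: weightBytes, path: 'model.onnx.data' }],
    });
}
```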
diff --git a/js/chat/index.html b/js/chat/index.html
new file mode 100644
index 000000000..6886c7459
--- /dev/null
+++ b/js/chat/index.html
@@ -0,0 +1,43 @@
[43 lines of HTML markup: a Bootstrap-styled page titled "Chat with onnxruntime-web" with a status area, the chat-container/chat-history panes, the user-input box with its Send button, and the includes for main.css and the bundled main script]
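The markup itself is minimal; what the scripts actually depend on is a handful of element ids. A small sketch of that contract, with the ids taken from the getElementById calls in main.js and llm.js (the check is illustrative, not part of the example):

```js
// Element ids that index.html is expected to provide for main.js / llm.js.
const requiredIds = ['status', 'chat-container', 'chat-history', 'user-input', 'send-button', 'scroll-wrapper'];
for (const id of requiredIds) {
    console.assert(document.getElementById(id) !== null, `index.html is missing #${id}`);
}
```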
diff --git a/js/chat/llm.js b/js/chat/llm.js
new file mode 100644
index 000000000..3bf51090b
--- /dev/null
+++ b/js/chat/llm.js
@@ -0,0 +1,226 @@
+import * as ort from 'onnxruntime-web/webgpu';
+
+ort.env.wasm.numThreads = 1;
+ort.env.wasm.simd = true;
+ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', '') + 'dist/';
+
+
+function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }
+
+//
+// load file from server or cache
+//
+async function fetchAndCache(url) {
+    try {
+        const cache = await caches.open("onnx");
+        let cachedResponse = await cache.match(url);
+        if (cachedResponse === undefined) {
+            log(`${url} (network)`);
+            const buffer = await fetch(url).then(response => response.arrayBuffer());
+            try {
+                await cache.put(url, new Response(buffer));
+            } catch (error) {
+                console.error(error);
+            }
+            return buffer;
+        }
+        log(`${url} (cached)`);
+        const data = await cachedResponse.arrayBuffer();
+        return data;
+    } catch (error) {
+        log(`can't fetch ${url}`);
+        throw error;
+    }
+}
+
+//
+// class to handle a large language model on top of onnxruntime-web
+//
+export class LLM {
+    sess = undefined;
+    profiler = false;
+    feed = {};
+    output_tokens = [];
+    eos = 2;
+    need_position_ids = true;
+    stop = false;
+    kv_dims = [];
+    dtype = "float16";
+    max_tokens = 9999;
+
+    constructor() {
+    }
+
+    async load(model, options) {
+        const provider = options.provider || "webgpu";
+        const verbose = options.verbose;
+        const local = options.local;
+        const hasFP16 = (provider === "wasm") ? false : options.hasFP16;
+        this.profiler = options.profiler;
+
+        const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main";
+        let model_file = model.file || "model";
+        model_file = (hasFP16) ? model_file + "_q4f16.onnx" : model_file + "_q4.onnx";
+
+        log(`loading... ${model.name}, ${provider}`);
+        const json_bytes = await fetchAndCache(model_path + "/config.json");
+        let textDecoder = new TextDecoder();
+        const model_config = JSON.parse(textDecoder.decode(json_bytes));
+
+        const model_bytes = await fetchAndCache(model_path + "/onnx/" + model_file);
+        const externaldata = (model.externaldata) ? await fetchAndCache(model_path + "/onnx/" + model_file + '_data') : false;
+        let modelSize = model_bytes.byteLength;
+        if (externaldata) {
+            modelSize += externaldata.byteLength;
+        }
+        log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`);
+
+        const opt = {
+            executionProviders: [provider],
+            preferredOutputLocation: {},
+        }
+
+        switch (provider) {
+            case "webgpu":
+                for (let i = 0; i < model_config.num_hidden_layers; ++i) {
+                    opt.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
+                    opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
+                }
+                break;
+        }
+
+        // only pass external data when we actually fetched it
+        if (externaldata) {
+            opt.externalData = [
+                {
+                    data: externaldata,
+                    path: model_file + "_data",
+                },
+            ]
+        }
+        if (verbose) {
+            opt.logSeverityLevel = 0;
+            opt.logVerbosityLevel = 0;
+            ort.env.logLevel = "verbose";
+        }
+
+        ort.env.webgpu.profiling = {}
+        if (this.profiler) {
+            opt.enableProfiling = true;
+            ort.env.webgpu.profilingMode = 'default';
+            ort.env.webgpu.profiling.mode = 'default';
+        }
+
+        this.sess = await ort.InferenceSession.create(model_bytes, opt);
+        this.eos = model_config.eos_token_id;
+        this.kv_dims = [1, model_config.num_key_value_heads, 0, model_config.hidden_size / model_config.num_attention_heads];
"float16" : "float32"; + this.num_layers = model_config.num_hidden_layers; + this.initilize_feed(); + } + + initilize_feed() { + const feed = this.feed; + + // dispose of previous gpu buffers + for (const name in feed) { + const t = feed[name]; + if (t.location === 'gpu-buffer') { + t.dispose(); + } + } + this.feed = {}; + // key value cache is zero copy, just pass gpu buffer as referece + const empty = (this.dtype === "float16") ? new Uint16Array() : []; + for (let i = 0; i < this.num_layers; ++i) { + this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims) + this.feed[`past_key_values.${i}.value`] = new ort.Tensor(this.dtype, empty, this.kv_dims) + } + this.output_tokens = []; + } + + // + // poor mens argmax + argmax(t) { + const arr = t.data; + const start = t.dims[2] * (t.dims[1] - 1); + let max = arr[start]; + let maxidx = 0; + + for (let i = 0; i < t.dims[2]; i++) { + const val = arr[i + start]; + if (!isFinite(val)) { + throw new Error("found infinitive in logits"); + } + if (val > max) { + max = arr[i + start]; + maxidx = i; + } + } + return maxidx; + } + + // + // update key value cache + // + update_kv_cache(feed, outputs) { + for (const name in outputs) { + if (name.startsWith('present')) { + let newName = name.replace('present', 'past_key_values'); + // dispose previous gpu buffers + const t = feed[newName]; + if (t.location === 'gpu-buffer') { + t.dispose(); + } + feed[newName] = outputs[name]; + } + } + } + + // + // tell generate to stop() + // + abort() { + this.stop = true; + } + + // + // prefill prompt and generate tokens, greedy search only + // + async generate(tokens, callback, options) { + const max_tokens = options.max_tokens || 256; + const feed = this.feed; + const input_ids = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]); + feed['input_ids'] = input_ids; + this.stop = false; + + this.output_tokens.push(...input_ids.data); + + let last_token = 0n; + let seqlen = this.output_tokens.length; + const input_len = input_ids.size; + + if (this.need_position_ids) { + feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: input_len }, (_, i) => BigInt(seqlen - input_len + i)), [1, input_len]); + } + + while (last_token != this.eos && last_token != 32007 && seqlen < max_tokens && !this.stop) { + seqlen = this.output_tokens.length; + feed['attention_mask'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, () => 1n), [1, seqlen]); + const outputs = await this.sess.run(feed); + last_token = BigInt(this.argmax(outputs.logits)); + this.output_tokens.push(last_token); + if (callback && !this.profiler) { + callback(this.output_tokens); + } + this.update_kv_cache(feed, outputs); + feed['input_ids'] = new ort.Tensor('int64', BigInt64Array.from([last_token]), [1, 1]); + if (this.need_position_ids) { + feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]); + } + } + if (this.profiler) { + this.sess.endProfiling(); + } + return this.output_tokens; + } +} diff --git a/js/chat/main.css b/js/chat/main.css new file mode 100644 index 000000000..d8088468e --- /dev/null +++ b/js/chat/main.css @@ -0,0 +1,57 @@ +body { + color: #f5f5f5; + font-family: 'Arial', sans-serif; +} + +.user-message { + background-color: rgb(86, 144, 163); + color: white; + padding: 10px; + border-radius: 10px; + white-space: pre-wrap; + width: fit-content; +} + +.response-message { + background-color: rgb(62, 62, 62); + color: white; + padding: 10px; + border-radius: 10px; + 
diff --git a/js/chat/main.css b/js/chat/main.css
new file mode 100644
index 000000000..d8088468e
--- /dev/null
+++ b/js/chat/main.css
@@ -0,0 +1,57 @@
+body {
+    color: #f5f5f5;
+    font-family: 'Arial', sans-serif;
+}
+
+.user-message {
+    background-color: rgb(86, 144, 163);
+    color: white;
+    padding: 10px;
+    border-radius: 10px;
+    white-space: pre-wrap;
+    width: fit-content;
+}
+
+.response-message {
+    background-color: rgb(62, 62, 62);
+    color: white;
+    padding: 10px;
+    border-radius: 10px;
+    padding-right: 20px;
+    position: relative;
+    margin-right: auto;
+}
+
+.response-message p {
+    margin-right: 40px;
+}
+
+#chat-container {
+    display: none;
+    margin: 0 auto;
+    overflow: auto;
+}
+
+#chat-history {
+    display: flex;
+    flex-direction: column;
+}
+
+.copy-button {
+    position: absolute;
+    bottom: 5px;
+    right: 5px;
+    margin: 0 5px 5px 0;
+}
+
+#scroll-wrapper {
+    padding-bottom: 5.5rem;
+}
+
+#input-area {
+    position: fixed;
+    bottom: 0;
+    margin-bottom: 5px;
+    left: 50%;
+    transform: translateX(-50%);
+}
\ No newline at end of file
diff --git a/js/chat/main.js b/js/chat/main.js
new file mode 100644
index 000000000..fa92ddb80
--- /dev/null
+++ b/js/chat/main.js
@@ -0,0 +1,306 @@
+import { env, AutoTokenizer } from '@xenova/transformers';
+import { LLM } from './llm.js';
+import { marked } from 'marked';
+
+
+const MODELS = {
+    "phi3": { name: "phi3", path: "microsoft/Phi-3-mini-4k-instruct-onnx-web", externaldata: true },
+    "phi3dev": { name: "phi3dev", path: "schmuell/Phi-3-mini-4k-instruct-onnx-web", externaldata: true },
+}
+
+const preCannedQueries = {
+    "1": "Tell me about the lighthouse of Alexandria.",
+    "2": "Did the lighthouse of Alexandria exist at the same time as the library of Alexandria?",
+    "3": "How did the Pharos lighthouse impact ancient maritime trade?",
+    "4": "Tell me about Constantinople.",
+};
+
+// inline SVG markup for the copy-to-clipboard icon
+const clipboardIcon = `
+
+`
+
+function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }
+
+marked.use({ mangle: false, headerIds: false });
+
+const sendButton = document.getElementById('send-button');
+const scrollWrapper = document.getElementById('scroll-wrapper');
+
+//
+// auto scroll the content area until a user scrolls up
+//
+let isAutoScrollOn = true;
+let lastKnownScrollPosition = 0;
+let ticking = false;
+
+const autoScroller = new ResizeObserver(() => {
+    if (isAutoScrollOn) {
+        scrollWrapper.scrollIntoView({ behavior: "smooth", block: "end" });
+    }
+});
+
+document.addEventListener("scroll", () => {
+    if (!ticking && isAutoScrollOn && window.scrollY < lastKnownScrollPosition) {
+        window.requestAnimationFrame(() => {
+            isAutoScrollOn = false;
+            ticking = false;
+        });
+        ticking = true;
+    }
+    else if (!ticking && !isAutoScrollOn && window.scrollY > lastKnownScrollPosition &&
+        window.scrollY >= document.documentElement.scrollHeight - window.innerHeight - 30) {
+        window.requestAnimationFrame(() => {
+            isAutoScrollOn = true;
+            ticking = false;
+        });
+        ticking = true;
+    }
+    lastKnownScrollPosition = window.scrollY;
+});
+
+
+//
+// make response available for copying to clipboard
+//
+function copyTextToClipboard(responseDiv) {
+    let elem = responseDiv;
+    const copyButton = document.createElement('button');
+    copyButton.className = 'btn btn-secondary copy-button';
+    copyButton.innerHTML = clipboardIcon;
+    elem = copyButton;
+    elem.onclick = () => {
+        navigator.clipboard.writeText(responseDiv.innerText);
+    };
+    responseDiv.appendChild(elem);
+}
+
+//
+// user hits Send, Enter or Ctrl+Enter
+//
+async function submitRequest(e) {
+    if (sendButton.innerHTML == "Stop") {
+        llm.abort();
+        return;
+    }
+
+    // Enter clears the chat history, Ctrl+Enter continues the conversation
+    const continuation = e.ctrlKey && e.key === 'Enter';
+
+    document.getElementById('chat-container').style.display = 'block';
+
+    let input = document.getElementById('user-input').value;
+    if (input.length == 0) {
+        document.getElementById('chat-history').context = "";
+        let chatHistory = document.getElementById('chat-history');
+        while (chatHistory.firstChild) {
+            chatHistory.firstChild.remove();
+        }
+        return;
+    }
+    let context = document.getElementById('chat-history').context;
+    if (context === undefined) {
+        context = "";
+    }
+
+    // append to chat history
+    let chatHistory = document.getElementById('chat-history');
+    let userMessageDiv = document.createElement('div');
+    userMessageDiv.className = 'mb-2 user-message';
+    userMessageDiv.innerText = input;
+    chatHistory.appendChild(userMessageDiv);
+
+    // container for llm response
+    let responseDiv = document.createElement('div');
+    responseDiv.className = 'response-message mb-2 text-start';
+    responseDiv.style.minHeight = '3em';
+    let spinner = document.createElement('div');
+    spinner.className = 'spinner-border text-light';
+    spinner.setAttribute('role', 'status');
+    responseDiv.appendChild(spinner);
+    chatHistory.appendChild(responseDiv);
+
+    // toggle button to stop text generation
+    sendButton.innerHTML = "Stop";
+
+    // change autoScroller to keep track of our new responseDiv
+    autoScroller.observe(responseDiv);
+
+    if (continuation) {
+        input = context + " " + input;
+    }
+
+    Query(continuation, input, (word) => {
+        responseDiv.innerHTML = marked.parse(word);
+    }).then(() => {
+        chatHistory.context = responseDiv.innerHTML;
+        copyTextToClipboard(responseDiv, true);
+        sendButton.innerHTML = "Send";
+        spinner.remove();
+    }).catch(error => {
+        console.error(error);
+        sendButton.innerHTML = "Send";
+        spinner.remove();
+    });
+
+    // Clear user input
+    document.getElementById('user-input').value = '';
+}
+
+
+//
+// event listener for Ctrl+Enter or Enter
+//
+document.getElementById('user-input').addEventListener('keydown', function (e) {
+    if (e.ctrlKey) {
+        if (e.key === 'Enter') {
+            submitRequest(e);
+        } else {
+            const query = preCannedQueries[e.key];
+            if (query) {
+                document.getElementById('user-input').value = query;
+                submitRequest(e);
+            }
+        }
+    } else if (e.key === 'Enter') {
+        e.preventDefault();
+        submitRequest(e);
+    }
+});
+
+function getConfig() {
+    const query = window.location.search.substring(1);
+    var config = {
+        model: "phi3",
+        provider: "webgpu",
+        profiler: 0,
+        verbose: 0,
+        threads: 1,
+        show_special: 0,
+        csv: 0,
+        max_tokens: 9999,
+        local: 0,
+    }
+    let vars = query.split("&");
+    for (var i = 0; i < vars.length; i++) {
+        let pair = vars[i].split("=");
+        if (pair[0] in config) {
+            const key = pair[0];
+            const value = decodeURIComponent(pair[1]);
+            if (typeof config[key] == "number") {
+                config[key] = parseInt(value);
+            }
+            else {
+                config[key] = value;
+            }
+        } else if (pair[0].length > 0) {
+            throw new Error("unknown argument: " + pair[0]);
+        }
+    }
+    if (MODELS[config.model] !== undefined) {
+        config.model = MODELS[config.model];
+    }
+    return config;
+}
+
+const config = getConfig();
+
+// setup for transformers.js tokenizer
+env.localModelPath = 'models';
+env.allowRemoteModels = config.local == 0;
+env.allowLocalModels = config.local == 1;
+
+let tokenizer;
+
+const llm = new LLM();
+
+function token_to_text(tokenizer, tokens, startidx) {
+    const txt = tokenizer.decode(tokens.slice(startidx), { skip_special_tokens: config.show_special != 1, });
+    return txt;
+}
+
+async function Query(continuation, query, cb) {
+    let prompt = (continuation) ? query : `<|system|>\nYou are a friendly assistant.<|end|>\n<|user|>\n${query}<|end|>\n<|assistant|>\n`;
+
+    const { input_ids } = await tokenizer(prompt, { return_tensor: false, padding: true, truncation: true });
+
+    // clear caches
+    // TODO: use kv_cache for continuation
+    llm.initilize_feed();
+
+    const start_timer = performance.now();
+    const output_index = llm.output_tokens.length + input_ids.length;
+    const output_tokens = await llm.generate(input_ids, (output_tokens) => {
+        if (output_tokens.length == input_ids.length + 1) {
+            // time to first token
+            const took = (performance.now() - start_timer) / 1000;
+            console.log(`time to first token in ${took.toFixed(1)}sec, ${input_ids.length} tokens`);
+        }
+        cb(token_to_text(tokenizer, output_tokens, output_index));
+    }, { max_tokens: config.max_tokens });
+
+    const took = (performance.now() - start_timer) / 1000;
+    cb(token_to_text(tokenizer, output_tokens, output_index));
+    const seqlen = output_tokens.length - output_index;
+    console.log(`${seqlen} tokens in ${took.toFixed(1)}sec, ${(seqlen / took).toFixed(2)} tokens/sec`);
+}
+
+//
+// Load the model and tokenizer
+//
+async function Init(hasFP16) {
+    try {
+        tokenizer = await AutoTokenizer.from_pretrained(config.model.path);
+
+        log("Loading model...");
+        await llm.load(config.model, {
+            provider: config.provider,
+            profiler: config.profiler,
+            verbose: config.verbose,
+            local: config.local,
+            max_tokens: config.max_tokens,
+            hasFP16: hasFP16,
+        });
+        log("Ready.");
+    } catch (error) {
+        log(error);
+    }
+}
+
+//
+// Check if we have webgpu and fp16
+//
+async function hasWebGPU() {
+    // returns 0 for webgpu with f16, 1 for webgpu without f16, 2 for no webgpu
+    if (!("gpu" in navigator)) {
+        return 2;
+    }
+    try {
+        const adapter = await navigator.gpu.requestAdapter()
+        if (adapter.features.has('shader-f16')) {
+            return 0;
+        }
+        return 1;
+    } catch (e) {
+        return 2;
+    }
+}
+
+window.onload = () => {
+    hasWebGPU().then((supported) => {
+        if (supported < 2) {
+            if (supported == 1) {
+                log("Your GPU or Browser does not support webgpu with fp16, using fp32 instead.");
+            }
+            Init(supported === 0).then(() => {
+                // adjustPadding();
+                sendButton.addEventListener('click', submitRequest);
+                const userInput = document.getElementById('user-input');
+                document.getElementById("status").style.display = "none";
+                userInput.focus();
+            });
+        } else {
+            log("Your GPU or Browser does not support webgpu");
+        }
+    });
+}
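getConfig() above means the demo is steered entirely through URL query parameters: the accepted keys are the fields of the config object, and unknown keys throw. A few illustrative combinations:

```js
// Example URLs accepted by getConfig() (parameter names come from the config object above):
//   http://localhost:8080/                              -> defaults: phi3 model, webgpu provider
//   http://localhost:8080/?model=phi3dev                -> use the dev entry from MODELS
//   http://localhost:8080/?max_tokens=512&verbose=1     -> cap generation, verbose ORT logging
//   http://localhost:8080/?local=1                      -> load model and tokenizer from ./models
//   http://localhost:8080/?profiler=1                   -> enable WebGPU profiling
// Navigating with a query string programmatically has the same effect:
window.location.search = new URLSearchParams({ model: "phi3dev", max_tokens: "512" }).toString();
```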
diff --git a/js/chat/package.json b/js/chat/package.json
new file mode 100644
index 000000000..098408a7b
--- /dev/null
+++ b/js/chat/package.json
@@ -0,0 +1,22 @@
+{
+    "name": "localchat",
+    "private": true,
+    "version": "0.0.0",
+    "type": "module",
+    "scripts": {
+        "dev": "webpack serve --no-client-overlay",
+        "build": "webpack",
+        "lint": "eslint . --ext js --report-unused-disable-directives"
+    },
+    "dependencies": {
+        "@xenova/transformers": "^2.17.1",
+        "copy-webpack-plugin": "^12.0.2",
+        "marked": "^12.0.2",
+        "onnxruntime-web": "1.19.0-dev.20240509-69cfcba38a",
+        "webpack": "^5.91.0"
+    },
+    "devDependencies": {
+        "webpack-cli": "^5.1.4",
+        "webpack-dev-server": "^5.0.4"
+    }
+}
diff --git a/js/chat/webpack.config.js b/js/chat/webpack.config.js
new file mode 100644
index 000000000..c9b78c380
--- /dev/null
+++ b/js/chat/webpack.config.js
@@ -0,0 +1,41 @@
+import CopyWebpackPlugin from 'copy-webpack-plugin';
+import { fileURLToPath } from 'url';
+import path from 'path';
+
+const __dirname = path.dirname(fileURLToPath(import.meta.url));
+
+export default {
+    mode: 'development',
+    devtool: 'source-map',
+    entry: {
+        'dist/main': './main.js',
+        'dist/main.min': './main.js',
+    },
+    output: {
+        filename: '[name].js',
+        path: __dirname,
+        library: {
+            type: 'module',
+        },
+    },
+    plugins: [
+        // Copy .wasm files to dist folder
+        new CopyWebpackPlugin({
+            patterns: [
+                {
+                    from: 'node_modules/onnxruntime-web/dist/*.jsep.*',
+                    to: 'dist/[name][ext]'
+                },
+            ],
+        }),
+    ],
+    devServer: {
+        static: {
+            directory: __dirname
+        },
+        port: 8080
+    },
+    experiments: {
+        outputModule: true,
+    },
+};