From 782f2fa2dae3cf04b17783dbc4e50b1abc8170e3 Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 10 May 2024 17:38:21 -0700 Subject: [PATCH 01/12] add phi3 ort-web example --- js/README.md | 3 + js/chat/README.md | 46 ++++++ js/chat/index.html | 43 ++++++ js/chat/llm.js | 210 ++++++++++++++++++++++++++++ js/chat/main.css | 57 ++++++++ js/chat/main.js | 307 +++++++++++++++++++++++++++++++++++++++++ js/chat/package.json | 23 +++ js/chat/vite.config.js | 35 +++++ 8 files changed, 724 insertions(+) create mode 100644 js/chat/README.md create mode 100644 js/chat/index.html create mode 100644 js/chat/llm.js create mode 100644 js/chat/main.css create mode 100644 js/chat/main.js create mode 100644 js/chat/package.json create mode 100644 js/chat/vite.config.js diff --git a/js/README.md b/js/README.md index de8d378e9..dcd93fbd3 100644 --- a/js/README.md +++ b/js/README.md @@ -49,3 +49,6 @@ Click links for README of each examples. * [Facebook Segment-Anything](segment-anything) - demonstrates how to run [segment-anything](https://github.com/facebookresearch/segment-anything) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu. * [Stable Diffusion Turbo](sd-turbo) - demonstrates how to run [Stable Diffusion Turbo](https://huggingface.co/stabilityai/sd-turbo) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu. + +* [Phi-3-mini-4k-instruct](chat) - demonstrates how to run [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu. + diff --git a/js/chat/README.md b/js/chat/README.md new file mode 100644 index 000000000..7431be566 --- /dev/null +++ b/js/chat/README.md @@ -0,0 +1,46 @@ +# Local Chat using Phi3, ONNX Runtime Web and WebGPU + +This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU. + +You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html). + +## Getting Started + +### Prerequisites + +Ensure that you have [Node.js](https://nodejs.org/) installed on your machine. + +### Installation + +Install the required dependencies: + +```sh +npm install +``` + +### Building the project + +Build the project using vite: + +```sh +npm run build +``` + +The output can be found in the ***dist*** directory. + +### Building for developent +For development you can use vite. +You must run ```npm run build``` once to setup the dist directory. + +```sh +npm run dev +``` + +Point your browser to http://localhost:5173/. + +### The ONNX Model + +The model used in this project is hosted on [Hugging Face](https://huggingface.co/schmuell/phi3-int4). It was created using the [onnx model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models). + +You create the model with +```python builder.py -m microsoft/Phi-3-mini-4k-instruct -o $your_output -p int4 -e web``` diff --git a/js/chat/index.html b/js/chat/index.html new file mode 100644 index 000000000..03da88dd2 --- /dev/null +++ b/js/chat/index.html @@ -0,0 +1,43 @@ + + + + + + + + + + Chat with onnxruntime-web + + + +
[unrecoverable index.html markup: Bootstrap chat layout with a "Chat with onnxruntime-web" card header, a chat-history container, a user-input box, a Send button, a status line, and a module script that loads main.js]
+ + + + + \ No newline at end of file diff --git a/js/chat/llm.js b/js/chat/llm.js new file mode 100644 index 000000000..33cc1eea4 --- /dev/null +++ b/js/chat/llm.js @@ -0,0 +1,210 @@ +import * as ort from 'onnxruntime-web/webgpu'; + +ort.env.wasm.numThreads = 1; +ort.env.wasm.simd = true; +ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', '') + 'dist/'; + + +function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; } + +// +// load file from server or cache +// +async function fetchAndCache(url) { + try { + const cache = await caches.open("onnx"); + let cachedResponse = await cache.match(url); + if (cachedResponse === undefined) { + log(`${url} (network)`); + const buffer = await fetch(url).then(response => response.arrayBuffer()); + try { + await cache.put(url, new Response(buffer)); + } catch (error) { + console.error(error); + } + return buffer; + } + log(`${url} (cached)`); + const data = await cachedResponse.arrayBuffer(); + return data; + } catch (error) { + log(`can't fetch ${url}`); + throw error; + } +} + +export class LLM { + sess = undefined; + profiler = false; + feed = {}; + output_tokens = []; + eos = 2; + need_position_ids = true; + stop = false; + kv_dims = []; + dtype = "float16"; + max_tokens = 9999; + + constructor() { + } + + async load(model, options) { + const provider = options.provider || "webgpu"; + const verbose = options.verbose; + const local = options.local; + const hasFP16 = (provider === "wasm") ? false : options.hasFP16; + this.profiler = options.profiler; + + const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main"; + let model_file = model.file || "model"; + model_file = (hasFP16 && model_file != "decoder_model_merged") ? model_file + "_fp16.onnx" : model_file + ".onnx"; + + log(`loading... ${model.name}, ${provider}`); + const json_bytes = await fetchAndCache(model_path + "/config.json"); + let textDecoder = new TextDecoder(); + const model_config = JSON.parse(textDecoder.decode(json_bytes)); + + const model_bytes = await fetchAndCache(model_path + "/onnx/" + model_file); + const externaldata = (model.externaldata) ? await fetchAndCache(model_path + "/onnx/" + model_file + '.data') : false; + let modelSize = model_bytes.byteLength; + if (externaldata) { + modelSize += externaldata.byteLength; + } + log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`); + + const opt = { + executionProviders: [provider], + preferredOutputLocation: {}, + } + + switch (provider) { + case "webgpu": + for (let i = 0; i < model_config.num_hidden_layers; ++i) { + opt.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer'; + opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer'; + } + break; + } + + if (externaldata !== undefined) { + opt.externalData = [ + { + data: externaldata, + path: model_file + ".data", + }, + ] + } + if (verbose) { + opt.logSeverityLevel = 0; + opt.logVerbosityLevel = 0; + ort.env.logLevel = "verbose"; + } + + ort.env.webgpu.profiling = {} + if (this.profiler) { + opt.enableProfiling = true; + ort.env.webgpu.profilingMode = 'default'; + ort.env.webgpu.profiling.mode = 'default'; + } + + this.sess = await ort.InferenceSession.create(model_bytes, opt); + this.eos = model_config.eos_token_id; + this.kv_dims = [1, model_config.num_key_value_heads, 0, model_config.hidden_size / model_config.num_attention_heads]; + this.dtype = (hasFP16) ? 
"float16" : "float32"; + this.num_layers = model_config.num_hidden_layers; + this.initilize_feed(); + } + + initilize_feed() { + const feed = this.feed; + for (const name in feed) { + const t = feed[name]; + if (t.location === 'gpu-buffer') { + t.dispose(); + } + } + this.feed = {}; + const empty = (this.dtype === "float16") ? new Uint16Array() : []; + for (let i = 0; i < this.num_layers; ++i) { + this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims) + this.feed[`past_key_values.${i}.value`] = new ort.Tensor(this.dtype, empty, this.kv_dims) + } + this.output_tokens = []; + } + + argmax(t) { + const arr = t.data; + const start = t.dims[2] * (t.dims[1] - 1); + let max = arr[start]; + let maxidx = 0; + + for (let i = 0; i < t.dims[2]; i++) { + const val = arr[i + start]; + if (!isFinite(val)) { + throw new Error("found infinitive in logits"); + } + if (val > max) { + max = arr[i + start]; + maxidx = i; + } + } + return maxidx; + } + + update_kv_cache(feed, outputs) { + for (const name in outputs) { + if (name.startsWith('present')) { + let newName = name.replace('present', 'past_key_values'); + // free old gpu buffer + const t = feed[newName]; + if (t.location === 'gpu-buffer') { + t.dispose(); + } + feed[newName] = outputs[name]; + } + } + } + + abort() { + this.stop = true; + } + + async generate(tokens, callback, options) { + const max_tokens = options.max_tokens || 256; + const feed = this.feed; + const input_ids = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]); + feed['input_ids'] = input_ids; + this.stop = false; + + this.output_tokens.push(...input_ids.data); + + let last_token = 0n; + let input_len = input_ids.size; + let seqlen = this.output_tokens.length; + + if (this.need_position_ids) { + feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: input_len }, (_, i) => BigInt(seqlen - input_len + i)), [1, input_len]); + } + + while (last_token != this.eos && last_token != 32007 && seqlen < max_tokens && !this.stop) { + seqlen = this.output_tokens.length; + feed['attention_mask'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, () => 1n), [1, seqlen]); + const outputs = await this.sess.run(feed); + last_token = BigInt(this.argmax(outputs.logits)); + this.output_tokens.push(last_token); + if (callback && !this.profiler) { + callback(this.output_tokens); + } + this.update_kv_cache(feed, outputs); + feed['input_ids'] = new ort.Tensor('int64', BigInt64Array.from([last_token]), [1, 1]); + input_len = 1; + if (this.need_position_ids) { + feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]); + } + } + if (this.profiler) { + this.sess.endProfiling(); + } + return this.output_tokens; + } +} diff --git a/js/chat/main.css b/js/chat/main.css new file mode 100644 index 000000000..61d6439bb --- /dev/null +++ b/js/chat/main.css @@ -0,0 +1,57 @@ +body { + color: #f5f5f5; + font-family: 'Arial', sans-serif; +} + +.user-message { + background-color: rgb(86, 144, 163); + color: white; + padding: 10px; + border-radius: 10px; + white-space: pre-wrap; + max-width: 70%; +} + +.response-message { + background-color: rgb(62, 62, 62); + color: white; + padding: 10px; + border-radius: 10px; + padding-right: 20px; + position: relative; + margin-right: auto; +} + +.response-message p { + margin-right: 40px; +} + +#chat-container { + display: none; + margin: 0 auto; + overflow: auto; +} + +#chat-history { + display: flex; + flex-direction: column; +} + +.copy-button { + 
position: absolute; + bottom: 5px; + right: 5px; + margin: 0 5px 5px 0; +} + +#scroll-wrapper { + padding-bottom: 5.5rem; +} + +#input-area { + position: fixed; + bottom: 0; + margin-bottom: 5px; + left: 50%; + transform: translateX(-50%); +} \ No newline at end of file diff --git a/js/chat/main.js b/js/chat/main.js new file mode 100644 index 000000000..d2441b577 --- /dev/null +++ b/js/chat/main.js @@ -0,0 +1,307 @@ +import { env, AutoTokenizer } from './dist/transformers.js'; +import { LLM } from './llm.js'; +import { marked } from 'marked'; + + +const MODELS = { + "tinyllama": { name: "tinyllama", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-int4", file: "decoder_model_merged" }, + "tinyllama_fp16": { name: "tinyllama-fp16", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-fp16", externaldata: true, file: "decoder_model_merged" }, + "phi2": { name: "phi2", path: "schmuell/phi2-int4", file: "decoder_model_merged" }, + "phi3": { name: "phi3", path: "schmuell/phi3-int4", externaldata: true }, + "phi3-1": { name: "phi3-1", path: "schmuell/phi3-1", externaldata: true }, + "stablelm": { name: "stablelm", path: "schmuell/stablelm-2-zephyr-1_6b-int4", file: "decoder_model_merged" }, +} + +const preCannedQueries = { + "1": "Tell me about the lighthouse of Alexandria.", + "2": "Did the lighthouse of Alexandria existed at the same time the library of Alexandria existed?", + "3": "How did the Pharos lighthouse impact ancient maritime trade?", + "4": "Tell me about Constantinople.", +}; + +const clipboardIcon = ` + + +` + +function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; } + +marked.use({ mangle: false, headerIds: false }); + +const sendButton = document.getElementById('send-button'); +const scrollWrapper = document.getElementById('scroll-wrapper'); + +// +// auto scroll the content area until a user scrolls up +// +let isAutoScrollOn = true; +let lastKnownScrollPosition = 0; +let ticking = false; + +const autoScroller = new ResizeObserver(() => { + if (isAutoScrollOn) { + scrollWrapper.scrollIntoView({ behavior: "smooth", block: "end" }); + } +}); + +document.addEventListener("scroll", () => { + if (!ticking && isAutoScrollOn && window.scrollY < lastKnownScrollPosition) { + window.requestAnimationFrame(() => { + isAutoScrollOn = false; + ticking = false; + }); + ticking = true; + } + else if (!ticking && !isAutoScrollOn && window.scrollY > lastKnownScrollPosition && + window.scrollY >= document.documentElement.scrollHeight - window.innerHeight - 30) { + window.requestAnimationFrame(() => { + isAutoScrollOn = true; + ticking = false; + }); + ticking = true; + } + lastKnownScrollPosition = window.scrollY; +}); + + +// +// make response available for copying to clipboard +// +function copyTextToClipboard(responseDiv) { + let elem = responseDiv; + const copyButton = document.createElement('button'); + copyButton.className = 'btn btn-secondary copy-button'; + copyButton.innerHTML = clipboardIcon; + elem = copyButton; + elem.onclick = () => { + navigator.clipboard.writeText(responseDiv.innerText); + }; + responseDiv.appendChild(elem); +} + +// +// user hits send, enter or ctl enter +// +async function submitRequest(e) { + if (sendButton.innerHTML == "Stop") { + llm.abort(); + return; + } + + // enter clears the chat history, ctl enter will continue the conversation + const continuation = e.ctrlKey && e.key === 'Enter'; + + document.getElementById('chat-container').style.display = 'block'; + + let input = document.getElementById('user-input').value; + if (input.length == 0) { 
+ document.getElementById('chat-history').context = ""; + let chatHistory = document.getElementById('chat-history'); + while (chatHistory.firstChild) { + chatHistory.firstChild.remove(); + } + return; + } + let context = document.getElementById('chat-history').context; + if (context === undefined) { + context = ""; + } + // Create user message element and append to chat history + let chatHistory = document.getElementById('chat-history'); + let userMessageDiv = document.createElement('div'); + userMessageDiv.className = 'mb-2 user-message'; + userMessageDiv.innerText = input; + chatHistory.appendChild(userMessageDiv); + + // Create response container + let responseDiv = document.createElement('div'); + responseDiv.className = 'response-message mb-2 text-start'; + responseDiv.style.minHeight = '3em'; + let spinner = document.createElement('div'); + spinner.className = 'spinner-border text-light'; + spinner.setAttribute('role', 'status'); + responseDiv.appendChild(spinner); + chatHistory.appendChild(responseDiv); + + // toggle button to stop text generation + sendButton.innerHTML = "Stop"; + + // change autoScroller to keep track of our new responseDiv + autoScroller.observe(responseDiv); + + if (continuation) { + input = context + " " + input; + } + + Query(continuation, input, (word) => { + responseDiv.innerHTML = marked.parse(word); + }).then(() => { + chatHistory.context = responseDiv.innerHTML; + copyTextToClipboard(responseDiv, true); + sendButton.innerHTML = "Send"; + spinner.remove(); + }).catch(error => { + console.error(error); + sendButton.innerHTML = "Send"; + spinner.remove(); + }); + + // Clear user input + document.getElementById('user-input').value = ''; +} + + +// +// event listener for Ctrl+Enter or Enter +// +document.getElementById('user-input').addEventListener('keydown', function (e) { + if (e.ctrlKey) { + if (e.key === 'Enter') { + submitRequest(e); + } else { + const query = preCannedQueries[e.key]; + if (query) { + document.getElementById('user-input').value = query; + submitRequest(e); + } + } + } else if (e.key === 'Enter') { + e.preventDefault(); + submitRequest(e); + } +}); + +function getConfig() { + const query = window.location.search.substring(1); + var config = { + model: "phi3", + provider: "webgpu", + profiler: 0, + verbose: 0, + threads: 1, + show_special: 0, + csv: 0, + max_tokens: 9999, + local: 0, + } + let vars = query.split("&"); + for (var i = 0; i < vars.length; i++) { + let pair = vars[i].split("="); + if (pair[0] in config) { + const key = pair[0]; + const value = decodeURIComponent(pair[1]); + if (typeof config[key] == "number") { + config[key] = parseInt(value); + } + else { + config[key] = value; + } + } else if (pair[0].length > 0) { + throw new Error("unknown argument: " + pair[0]); + } + } + if (MODELS[config.model] !== undefined) { + config.model = MODELS[config.model]; + } + return config; +} + +const config = getConfig(); + +// setup for transformers.js tokenizer +env.localModelPath = 'models'; +env.allowRemoteModels = config.local == 0; +env.allowLocalModels = config.local == 1; + +let tokenizer; + +const llm = new LLM(); + +function token_to_text(tokenizer, tokens, startidx) { + const txt = tokenizer.decode(tokens.slice(startidx), { skip_special_tokens: config.show_special != 1, }); + return txt; +} + +async function Query(continuation, query, cb) { + let prompt = (continuation) ? 
query : `<|system|>\nYou are a friendly assistant.<|end|>\n<|user|>\n${query}<|end|>\n<|assistant|>\n`; + + const { input_ids } = await tokenizer(prompt, { return_tensor: false, padding: true, truncation: true }); + + // clear caches TODO: for continuation we should use kv_cache + llm.initilize_feed(); + + const start_timer = performance.now(); + const output_index = llm.output_tokens.length + input_ids.length; + const output_tokens = await llm.generate(input_ids, (output_tokens) => { + if (output_tokens.length == input_ids.length + 1) { + // time to first token = prefill + const took = (performance.now() - start_timer) / 1000; + console.log(`time to first token in ${took.toFixed(1)}sec, ${input_ids.length} tokens`); + } + cb(token_to_text(tokenizer, output_tokens, output_index)); + }, { max_tokens: config.max_tokens }); + + const took = (performance.now() - start_timer) / 1000; + cb(token_to_text(tokenizer, output_tokens, output_index)); + const seqlen = output_tokens.length - output_index; + console.log(`${seqlen} tokens in ${took.toFixed(1)}sec, ${(seqlen / took).toFixed(2)} tokens/sec`); +} + +// +// Load the model and tokenizer +async function Init(hasFP16) { + try { + tokenizer = await AutoTokenizer.from_pretrained(config.model.path); + + log("Loading model..."); + await llm.load(config.model, { + provider: config.provider, + profiler: config.profiler, + verbose: config.verbose, + local: config.local, + max_tokens: config.max_tokens, + hasFP16: hasFP16, + }); + log("Ready."); + } catch (error) { + log(error); + } +} + +// +// Check if we have webgpu and fp16 +// +async function hasWebGPU() { + // returns 0 for webgpu with f16, 1 for webgpu without f16, 2 for no webgpu + if (!("gpu" in navigator)) { + return 2; + } + try { + const adapter = await navigator.gpu.requestAdapter() + if (adapter.features.has('shader-f16')) { + return 0; + } + return 1; + } catch (e) { + return 2; + } +} + +window.onload = () => { + hasWebGPU().then((supported) => { + if (supported < 2) { + if (supported == 1) { + log("Your GPU or Browser does not support webgpu with fp16, using fp32 instead."); + } + Init(supported === 0).then(() => { + // adjustPadding(); + sendButton.addEventListener('click', submitRequest); + const userInput = document.getElementById('user-input'); + document.getElementById("status").style.display = "none"; + userInput.focus(); + }); + } else { + log("Your GPU or Browser does not support webgpu"); + } + }); +} diff --git a/js/chat/package.json b/js/chat/package.json new file mode 100644 index 000000000..40be13eba --- /dev/null +++ b/js/chat/package.json @@ -0,0 +1,23 @@ +{ + "name": "localchat", + "private": true, + "version": "0.0.0", + "type": "module", + "scripts": { + "dev": "vite", + "build": "vite build", + "lint": "eslint . 
--ext js --report-unused-disable-directives", + "preview": "vite preview" + }, + "dependencies": { + "@xenova/transformers": "^2.17.1", + "dompurify": "^3.1.2", + "marked": "^12.0.2", + "onnxruntime-web": "^1.19.0-dev.20240509-69cfcba38a" + }, + "devDependencies": { + "eslint": "^8.55.0", + "vite": "^5.2.11", + "vite-plugin-static-copy": "^1.0.4" + } +} diff --git a/js/chat/vite.config.js b/js/chat/vite.config.js new file mode 100644 index 000000000..b39d1dd9e --- /dev/null +++ b/js/chat/vite.config.js @@ -0,0 +1,35 @@ +import { defineConfig } from "vite"; +import * as path from "path"; +import { fileURLToPath } from "node:url"; +import { viteStaticCopy } from 'vite-plugin-static-copy' + +const filesNeedToExclude = ["models", "dist/transformers.js"]; +const filesPathToExclude = filesNeedToExclude.map((src) => { + return fileURLToPath(new URL(src, import.meta.url)); +}); + +export default defineConfig({ + plugins: [ + viteStaticCopy({ + targets: [ + { + src: 'node_modules/onnxruntime-web/dist/*jsep*.wasm', + dest: path.join(__dirname, 'dist/dist') + }, + { + src: 'node_modules/@xenova/transformers/dist/transformers.js', + dest: path.join(__dirname, 'dist/dist') + } + + ] + }) + ], + build: { + outDir: "dist", + rollupOptions: { + external: [ + ...filesPathToExclude + ], + }, + }, +}); \ No newline at end of file From f80ffb80fa625e631c721bbcec91f1fbdea29a7a Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 10 May 2024 17:40:31 -0700 Subject: [PATCH 02/12] add phi3 ort-web example --- js/chat/README.md | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/js/chat/README.md b/js/chat/README.md index 7431be566..847444977 100644 --- a/js/chat/README.md +++ b/js/chat/README.md @@ -42,5 +42,9 @@ Point your browser to http://localhost:5173/. The model used in this project is hosted on [Hugging Face](https://huggingface.co/schmuell/phi3-int4). It was created using the [onnx model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models). -You create the model with -```python builder.py -m microsoft/Phi-3-mini-4k-instruct -o $your_output -p int4 -e web``` +You can create the model with + +```sh +python builder.py -m microsoft/Phi-3-mini-4k-instruct -o $your_output -p int4 -e web +``` + From 4a77bb372d2bb0c227bd4c9483cad5e46d5338c7 Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 10 May 2024 17:46:41 -0700 Subject: [PATCH 03/12] add phi3 ort-web example --- js/chat/README.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/js/chat/README.md b/js/chat/README.md index 847444977..35635ad71 100644 --- a/js/chat/README.md +++ b/js/chat/README.md @@ -4,6 +4,9 @@ This repository contains an example of running [Phi-3-mini-4k-instruct](https:// You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html). +We keep this example simple and use the onnxruntime-web api directly without a +more advanced framework like [transformers.js](https://github.com/xenova/transformers.js). 
+ ## Getting Started ### Prerequisites From 25dc4f64d0ac2428408b66bbd44c90eeb9d2dcce Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 10 May 2024 17:48:45 -0700 Subject: [PATCH 04/12] add phi3 ort-web example --- js/chat/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/chat/README.md b/js/chat/README.md index 35635ad71..94f2d4ef2 100644 --- a/js/chat/README.md +++ b/js/chat/README.md @@ -1,6 +1,6 @@ # Local Chat using Phi3, ONNX Runtime Web and WebGPU -This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU. +This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU. You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html). From d49732f36f8a1bb60984a2fb1c45c32b7f656296 Mon Sep 17 00:00:00 2001 From: guschmue Date: Fri, 10 May 2024 20:02:31 -0700 Subject: [PATCH 05/12] fix ort package version --- js/chat/package.json | 3 +-- js/chat/vite.config.js | 8 ++++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/js/chat/package.json b/js/chat/package.json index 40be13eba..92b70b6fd 100644 --- a/js/chat/package.json +++ b/js/chat/package.json @@ -11,9 +11,8 @@ }, "dependencies": { "@xenova/transformers": "^2.17.1", - "dompurify": "^3.1.2", "marked": "^12.0.2", - "onnxruntime-web": "^1.19.0-dev.20240509-69cfcba38a" + "onnxruntime-web": "^1.19.0-dev.20240505-a36692066d" }, "devDependencies": { "eslint": "^8.55.0", diff --git a/js/chat/vite.config.js b/js/chat/vite.config.js index b39d1dd9e..5f3e0b295 100644 --- a/js/chat/vite.config.js +++ b/js/chat/vite.config.js @@ -16,9 +16,17 @@ export default defineConfig({ src: 'node_modules/onnxruntime-web/dist/*jsep*.wasm', dest: path.join(__dirname, 'dist/dist') }, + { + src: 'node_modules/onnxruntime-web/dist/*jsep*.wasm', + dest: path.join(__dirname, 'dist') + }, { src: 'node_modules/@xenova/transformers/dist/transformers.js', dest: path.join(__dirname, 'dist/dist') + }, + { + src: 'node_modules/@xenova/transformers/dist/transformers.js', + dest: path.join(__dirname, 'dist') } ] From 9c412cf5a60dbfb0784844c9f4b95e52256f757c Mon Sep 17 00:00:00 2001 From: guschmue Date: Mon, 13 May 2024 09:30:07 -0700 Subject: [PATCH 06/12] pin ort, limit width of user messages --- js/chat/main.css | 2 +- js/chat/package.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/js/chat/main.css b/js/chat/main.css index 61d6439bb..d8088468e 100644 --- a/js/chat/main.css +++ b/js/chat/main.css @@ -9,7 +9,7 @@ body { padding: 10px; border-radius: 10px; white-space: pre-wrap; - max-width: 70%; + width: fit-content; } .response-message { diff --git a/js/chat/package.json b/js/chat/package.json index 92b70b6fd..558aad470 100644 --- a/js/chat/package.json +++ b/js/chat/package.json @@ -12,7 +12,7 @@ "dependencies": { "@xenova/transformers": "^2.17.1", "marked": "^12.0.2", - "onnxruntime-web": "^1.19.0-dev.20240505-a36692066d" + "onnxruntime-web": "1.19.0-dev.20240509-69cfcba38a" }, "devDependencies": { "eslint": "^8.55.0", From 83366de316b9c25d647c70524068ffdf7191953e Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 14 May 2024 10:18:08 -0700 Subject: [PATCH 07/12] switch to webpack --- js/chat/index.html | 2 +- js/chat/llm.js | 24 ++++++++++++++++++---- 
js/chat/main.js | 13 +++++++----- js/chat/package.json | 16 +++++++-------- js/chat/vite.config.js | 43 --------------------------------------- js/chat/webpack.config.js | 41 +++++++++++++++++++++++++++++++++++++ 6 files changed, 78 insertions(+), 61 deletions(-) delete mode 100644 js/chat/vite.config.js create mode 100644 js/chat/webpack.config.js diff --git a/js/chat/index.html b/js/chat/index.html index 03da88dd2..6886c7459 100644 --- a/js/chat/index.html +++ b/js/chat/index.html @@ -37,7 +37,7 @@

- + \ No newline at end of file diff --git a/js/chat/llm.js b/js/chat/llm.js index 33cc1eea4..09d9776f3 100644 --- a/js/chat/llm.js +++ b/js/chat/llm.js @@ -2,7 +2,7 @@ import * as ort from 'onnxruntime-web/webgpu'; ort.env.wasm.numThreads = 1; ort.env.wasm.simd = true; -ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', '') + 'dist/'; +ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', 'dist/'); function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; } @@ -33,6 +33,9 @@ async function fetchAndCache(url) { } } +// +// class to handle a large language model on top of onnxruntime-web +// export class LLM { sess = undefined; profiler = false; @@ -117,6 +120,8 @@ export class LLM { initilize_feed() { const feed = this.feed; + + // dispose of previous gpu buffers for (const name in feed) { const t = feed[name]; if (t.location === 'gpu-buffer') { @@ -124,6 +129,7 @@ export class LLM { } } this.feed = {}; + // key value cache is zero copy, just pass gpu buffer as referece const empty = (this.dtype === "float16") ? new Uint16Array() : []; for (let i = 0; i < this.num_layers; ++i) { this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims) @@ -132,6 +138,8 @@ export class LLM { this.output_tokens = []; } + // + // poor mens argmax argmax(t) { const arr = t.data; const start = t.dims[2] * (t.dims[1] - 1); @@ -151,11 +159,14 @@ export class LLM { return maxidx; } + // + // update key value cache + // update_kv_cache(feed, outputs) { for (const name in outputs) { if (name.startsWith('present')) { let newName = name.replace('present', 'past_key_values'); - // free old gpu buffer + // dispose previous gpu buffers const t = feed[newName]; if (t.location === 'gpu-buffer') { t.dispose(); @@ -165,10 +176,16 @@ export class LLM { } } + // + // tell generate to stop() + // abort() { this.stop = true; } + // + // prefill prompt and generate tokens + // async generate(tokens, callback, options) { const max_tokens = options.max_tokens || 256; const feed = this.feed; @@ -179,8 +196,8 @@ export class LLM { this.output_tokens.push(...input_ids.data); let last_token = 0n; - let input_len = input_ids.size; let seqlen = this.output_tokens.length; + const input_len = input_ids.size; if (this.need_position_ids) { feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: input_len }, (_, i) => BigInt(seqlen - input_len + i)), [1, input_len]); @@ -197,7 +214,6 @@ export class LLM { } this.update_kv_cache(feed, outputs); feed['input_ids'] = new ort.Tensor('int64', BigInt64Array.from([last_token]), [1, 1]); - input_len = 1; if (this.need_position_ids) { feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]); } diff --git a/js/chat/main.js b/js/chat/main.js index d2441b577..d2bd3d231 100644 --- a/js/chat/main.js +++ b/js/chat/main.js @@ -1,4 +1,4 @@ -import { env, AutoTokenizer } from './dist/transformers.js'; +import { env, AutoTokenizer } from '@xenova/transformers'; import { LLM } from './llm.js'; import { marked } from 'marked'; @@ -106,14 +106,15 @@ async function submitRequest(e) { if (context === undefined) { context = ""; } - // Create user message element and append to chat history + + // append to chat history let chatHistory = document.getElementById('chat-history'); let userMessageDiv = document.createElement('div'); userMessageDiv.className = 'mb-2 user-message'; userMessageDiv.innerText = input; chatHistory.appendChild(userMessageDiv); - // Create 
response container + // container for llm response let responseDiv = document.createElement('div'); responseDiv.className = 'response-message mb-2 text-start'; responseDiv.style.minHeight = '3em'; @@ -227,14 +228,15 @@ async function Query(continuation, query, cb) { const { input_ids } = await tokenizer(prompt, { return_tensor: false, padding: true, truncation: true }); - // clear caches TODO: for continuation we should use kv_cache + // clear caches + // TODO: use kv_cache for continuation llm.initilize_feed(); const start_timer = performance.now(); const output_index = llm.output_tokens.length + input_ids.length; const output_tokens = await llm.generate(input_ids, (output_tokens) => { if (output_tokens.length == input_ids.length + 1) { - // time to first token = prefill + // time to first token const took = (performance.now() - start_timer) / 1000; console.log(`time to first token in ${took.toFixed(1)}sec, ${input_ids.length} tokens`); } @@ -249,6 +251,7 @@ async function Query(continuation, query, cb) { // // Load the model and tokenizer +// async function Init(hasFP16) { try { tokenizer = await AutoTokenizer.from_pretrained(config.model.path); diff --git a/js/chat/package.json b/js/chat/package.json index 558aad470..098408a7b 100644 --- a/js/chat/package.json +++ b/js/chat/package.json @@ -4,19 +4,19 @@ "version": "0.0.0", "type": "module", "scripts": { - "dev": "vite", - "build": "vite build", - "lint": "eslint . --ext js --report-unused-disable-directives", - "preview": "vite preview" + "dev": "webpack serve --no-client-overlay", + "build": "webpack", + "lint": "eslint . --ext js --report-unused-disable-directives" }, "dependencies": { "@xenova/transformers": "^2.17.1", + "copy-webpack-plugin": "^12.0.2", "marked": "^12.0.2", - "onnxruntime-web": "1.19.0-dev.20240509-69cfcba38a" + "onnxruntime-web": "1.19.0-dev.20240509-69cfcba38a", + "webpack": "^5.91.0" }, "devDependencies": { - "eslint": "^8.55.0", - "vite": "^5.2.11", - "vite-plugin-static-copy": "^1.0.4" + "webpack-cli": "^5.1.4", + "webpack-dev-server": "^5.0.4" } } diff --git a/js/chat/vite.config.js b/js/chat/vite.config.js deleted file mode 100644 index 5f3e0b295..000000000 --- a/js/chat/vite.config.js +++ /dev/null @@ -1,43 +0,0 @@ -import { defineConfig } from "vite"; -import * as path from "path"; -import { fileURLToPath } from "node:url"; -import { viteStaticCopy } from 'vite-plugin-static-copy' - -const filesNeedToExclude = ["models", "dist/transformers.js"]; -const filesPathToExclude = filesNeedToExclude.map((src) => { - return fileURLToPath(new URL(src, import.meta.url)); -}); - -export default defineConfig({ - plugins: [ - viteStaticCopy({ - targets: [ - { - src: 'node_modules/onnxruntime-web/dist/*jsep*.wasm', - dest: path.join(__dirname, 'dist/dist') - }, - { - src: 'node_modules/onnxruntime-web/dist/*jsep*.wasm', - dest: path.join(__dirname, 'dist') - }, - { - src: 'node_modules/@xenova/transformers/dist/transformers.js', - dest: path.join(__dirname, 'dist/dist') - }, - { - src: 'node_modules/@xenova/transformers/dist/transformers.js', - dest: path.join(__dirname, 'dist') - } - - ] - }) - ], - build: { - outDir: "dist", - rollupOptions: { - external: [ - ...filesPathToExclude - ], - }, - }, -}); \ No newline at end of file diff --git a/js/chat/webpack.config.js b/js/chat/webpack.config.js new file mode 100644 index 000000000..81562ce49 --- /dev/null +++ b/js/chat/webpack.config.js @@ -0,0 +1,41 @@ +import CopyWebpackPlugin from 'copy-webpack-plugin'; +import { fileURLToPath } from 'url'; +import path from 'path'; 
+ +const __dirname = path.dirname(fileURLToPath(import.meta.url)); + +export default { + mode: 'development', + devtool: 'source-map', + entry: { + 'dist/main': './main.js', + 'dist/main.min': './main.js', + }, + output: { + filename: '[name].js', + path: __dirname, + library: { + type: 'module', + }, + }, + plugins: [ + // Copy .wasm files to dist folder + new CopyWebpackPlugin({ + patterns: [ + { + from: 'node_modules/onnxruntime-web/dist/*.jsep.wasm', + to: 'dist/[name][ext]' + }, + ], + }), + ], + devServer: { + static: { + directory: __dirname + }, + port: 8080 + }, + experiments: { + outputModule: true, + }, +}; From d0a95c5ee5b086e1685e70c993347718d9acc700 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 14 May 2024 10:44:10 -0700 Subject: [PATCH 08/12] update readme to reflect webpack --- js/chat/README.md | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/js/chat/README.md b/js/chat/README.md index 94f2d4ef2..ff4d1dd6a 100644 --- a/js/chat/README.md +++ b/js/chat/README.md @@ -5,7 +5,7 @@ This repository contains an example of running [Phi-3-mini-4k-instruct](https:// You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html). We keep this example simple and use the onnxruntime-web api directly without a -more advanced framework like [transformers.js](https://github.com/xenova/transformers.js). +higher level framework like [transformers.js](https://github.com/xenova/transformers.js). ## Getting Started @@ -23,7 +23,7 @@ npm install ### Building the project -Build the project using vite: +Build the project: ```sh npm run build @@ -32,14 +32,13 @@ npm run build The output can be found in the ***dist*** directory. ### Building for developent -For development you can use vite. -You must run ```npm run build``` once to setup the dist directory. ```sh npm run dev ``` -Point your browser to http://localhost:5173/. +This will build the project and start a dev server. +Point your browser to http://localhost:8080/. ### The ONNX Model @@ -50,4 +49,3 @@ You can create the model with ```sh python builder.py -m microsoft/Phi-3-mini-4k-instruct -o $your_output -p int4 -e web ``` - From df76069bb78588ca8dbfc0030b95bffd7431e586 Mon Sep 17 00:00:00 2001 From: guschmue Date: Tue, 14 May 2024 19:38:58 -0700 Subject: [PATCH 09/12] fix wasm path --- js/chat/llm.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/chat/llm.js b/js/chat/llm.js index 09d9776f3..bdde922cb 100644 --- a/js/chat/llm.js +++ b/js/chat/llm.js @@ -2,7 +2,7 @@ import * as ort from 'onnxruntime-web/webgpu'; ort.env.wasm.numThreads = 1; ort.env.wasm.simd = true; -ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', 'dist/'); +ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', '') + 'dist/'; function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; } From 712a887dec070c9a32c7b3d184275528ea4b1e03 Mon Sep 17 00:00:00 2001 From: guschmue Date: Thu, 16 May 2024 10:25:29 -0700 Subject: [PATCH 10/12] add olive instructions to readme --- js/chat/README.md | 15 +++++++++------ js/chat/main.js | 5 ----- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/js/chat/README.md b/js/chat/README.md index ff4d1dd6a..0b9e2d7f3 100644 --- a/js/chat/README.md +++ b/js/chat/README.md @@ -40,12 +40,15 @@ npm run dev This will build the project and start a dev server. Point your browser to http://localhost:8080/. 
-### The ONNX Model +### The Phi3 ONNX Model -The model used in this project is hosted on [Hugging Face](https://huggingface.co/schmuell/phi3-int4). It was created using the [onnx model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models). +The model used in this example is hosted on [Hugging Face](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx-web). It is slightly different than the ONNX model for CUDA or CPU: +1. The model output 'logits' is kept as float32 (even for float16 models) since Javascript does not support float16. +2. Our WebGPU implementation uses the custom Multiheaded Attention operator instread of Group Query Attention. +3. Phi3 is larger then 2GB and we need to use external data files. To keep them cacheable in the browser, + both model.onnx and model.onnx.data are kept under 2GB. -You can create the model with +The model was created using the [ONNX genai model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models). -```sh -python builder.py -m microsoft/Phi-3-mini-4k-instruct -o $your_output -p int4 -e web -``` +If you like to create the model yourself, you can use [Olive](https://github.com/microsoft/Olive/). +An example how to create the model for ONNX Runtime Web with Olive can be found [here](https://github.com/microsoft/Olive/tree/main/examples/phi3). diff --git a/js/chat/main.js b/js/chat/main.js index d2bd3d231..10d7599c0 100644 --- a/js/chat/main.js +++ b/js/chat/main.js @@ -4,12 +4,7 @@ import { marked } from 'marked'; const MODELS = { - "tinyllama": { name: "tinyllama", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-int4", file: "decoder_model_merged" }, - "tinyllama_fp16": { name: "tinyllama-fp16", path: "schmuell/TinyLlama-1.1B-Chat-v1.0-fp16", externaldata: true, file: "decoder_model_merged" }, - "phi2": { name: "phi2", path: "schmuell/phi2-int4", file: "decoder_model_merged" }, "phi3": { name: "phi3", path: "schmuell/phi3-int4", externaldata: true }, - "phi3-1": { name: "phi3-1", path: "schmuell/phi3-1", externaldata: true }, - "stablelm": { name: "stablelm", path: "schmuell/stablelm-2-zephyr-1_6b-int4", file: "decoder_model_merged" }, } const preCannedQueries = { From 46708976705899bbc83bd966dfdd536f486f0ce6 Mon Sep 17 00:00:00 2001 From: guschmue Date: Thu, 16 May 2024 12:12:56 -0700 Subject: [PATCH 11/12] new naming convention and location for the model --- js/chat/llm.js | 8 ++++---- js/chat/main.js | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/js/chat/llm.js b/js/chat/llm.js index bdde922cb..3bf51090b 100644 --- a/js/chat/llm.js +++ b/js/chat/llm.js @@ -60,7 +60,7 @@ export class LLM { const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main"; let model_file = model.file || "model"; - model_file = (hasFP16 && model_file != "decoder_model_merged") ? model_file + "_fp16.onnx" : model_file + ".onnx"; + model_file = (hasFP16) ? model_file + "_q4f16.onnx" : model_file + "_q4.onnx"; log(`loading... ${model.name}, ${provider}`); const json_bytes = await fetchAndCache(model_path + "/config.json"); @@ -68,7 +68,7 @@ export class LLM { const model_config = JSON.parse(textDecoder.decode(json_bytes)); const model_bytes = await fetchAndCache(model_path + "/onnx/" + model_file); - const externaldata = (model.externaldata) ? await fetchAndCache(model_path + "/onnx/" + model_file + '.data') : false; + const externaldata = (model.externaldata) ? 
await fetchAndCache(model_path + "/onnx/" + model_file + '_data') : false; let modelSize = model_bytes.byteLength; if (externaldata) { modelSize += externaldata.byteLength; @@ -93,7 +93,7 @@ export class LLM { opt.externalData = [ { data: externaldata, - path: model_file + ".data", + path: model_file + "_data", }, ] } @@ -184,7 +184,7 @@ export class LLM { } // - // prefill prompt and generate tokens + // prefill prompt and generate tokens, greedy search only // async generate(tokens, callback, options) { const max_tokens = options.max_tokens || 256; diff --git a/js/chat/main.js b/js/chat/main.js index 10d7599c0..fa92ddb80 100644 --- a/js/chat/main.js +++ b/js/chat/main.js @@ -4,7 +4,8 @@ import { marked } from 'marked'; const MODELS = { - "phi3": { name: "phi3", path: "schmuell/phi3-int4", externaldata: true }, + "phi3": { name: "phi3", path: "microsoft/Phi-3-mini-4k-instruct-onnx-web", externaldata: true }, + "phi3dev": { name: "phi3dev", path: "schmuell/Phi-3-mini-4k-instruct-onnx-web", externaldata: true }, } const preCannedQueries = { From 6bde4f3d2ae2fad3a72f0de2aed652717b0218c3 Mon Sep 17 00:00:00 2001 From: guschmue Date: Thu, 16 May 2024 12:32:15 -0700 Subject: [PATCH 12/12] future proof build config --- js/chat/webpack.config.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/chat/webpack.config.js b/js/chat/webpack.config.js index 81562ce49..c9b78c380 100644 --- a/js/chat/webpack.config.js +++ b/js/chat/webpack.config.js @@ -23,7 +23,7 @@ export default { new CopyWebpackPlugin({ patterns: [ { - from: 'node_modules/onnxruntime-web/dist/*.jsep.wasm', + from: 'node_modules/onnxruntime-web/dist/*.jsep.*', to: 'dist/[name][ext]' }, ],