diff --git a/js/README.md b/js/README.md
index de8d378e9..dcd93fbd3 100644
--- a/js/README.md
+++ b/js/README.md
@@ -49,3 +49,6 @@ Click links for README of each examples.
* [Facebook Segment-Anything](segment-anything) - demonstrates how to run [segment-anything](https://github.com/facebookresearch/segment-anything) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.
* [Stable Diffusion Turbo](sd-turbo) - demonstrates how to run [Stable Diffusion Turbo](https://huggingface.co/stabilityai/sd-turbo) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.
+
+* [Phi-3-mini-4k-instruct](chat) - demonstrates how to run [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.
+
diff --git a/js/chat/README.md b/js/chat/README.md
new file mode 100644
index 000000000..0b9e2d7f3
--- /dev/null
+++ b/js/chat/README.md
@@ -0,0 +1,54 @@
+# Local Chat using Phi3, ONNX Runtime Web and WebGPU
+
+This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU.
+
+You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html).
+
+We keep this example simple and use the onnxruntime-web API directly, without a
+higher-level framework like [transformers.js](https://github.com/xenova/transformers.js) (only its tokenizer is used).
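+
+At its core, the flow in [main.js](main.js) boils down to roughly the following (a simplified sketch; UI wiring, caching and error handling are omitted):
+
+```js
+import { AutoTokenizer } from '@xenova/transformers';
+import { LLM } from './llm.js';
+
+const model = { name: "phi3", path: "microsoft/Phi-3-mini-4k-instruct-onnx-web", externaldata: true };
+const tokenizer = await AutoTokenizer.from_pretrained(model.path);
+
+const llm = new LLM();
+await llm.load(model, { provider: "webgpu", hasFP16: true });
+
+const prompt = "<|system|>\nYou are a friendly assistant.<|end|>\n<|user|>\nTell me about Constantinople.<|end|>\n<|assistant|>\n";
+const { input_ids } = await tokenizer(prompt, { return_tensor: false, padding: true, truncation: true });
+
+// generate() streams the growing token sequence into the callback
+const output_tokens = await llm.generate(input_ids, (tokens) => {
+    console.log(tokenizer.decode(tokens.slice(input_ids.length), { skip_special_tokens: true }));
+}, { max_tokens: 256 });
+```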
+
+## Getting Started
+
+### Prerequisites
+
+Ensure that you have [Node.js](https://nodejs.org/) installed on your machine.
+
+### Installation
+
+Install the required dependencies:
+
+```sh
+npm install
+```
+
+### Building the project
+
+Build the project:
+
+```sh
+npm run build
+```
+
+The output can be found in the ***dist*** directory.
+
+### Building for development
+
+```sh
+npm run dev
+```
+
+This will build the project and start a dev server.
+Point your browser to http://localhost:8080/.
+
+### The Phi3 ONNX Model
+
+The model used in this example is hosted on [Hugging Face](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx-web). It is slightly different from the ONNX model for CUDA or CPU:
+1. The model output 'logits' is kept as float32 (even for float16 models) since JavaScript does not support float16.
+2. Our WebGPU implementation uses the custom Multi-Head Attention operator instead of Group Query Attention.
+3. Phi3 is larger than 2GB, so we need to use external data files. To keep them cacheable in the browser,
+   both model.onnx and model.onnx.data are kept under 2GB.
+
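+The external data (point 3) is wired up through onnxruntime-web's `externalData` session option, roughly as done in [llm.js](llm.js) (buffer names here are illustrative):
+
+```js
+import * as ort from 'onnxruntime-web/webgpu';
+
+// modelBuffer and externalDataBuffer hold the fetched onnx model and its external data file
+const session = await ort.InferenceSession.create(modelBuffer, {
+    executionProviders: ["webgpu"],
+    externalData: [{ data: externalDataBuffer, path: "model_q4f16.onnx_data" }],
+});
+```
+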
+The model was created using the [ONNX genai model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models).
+
+If you would like to create the model yourself, you can use [Olive](https://github.com/microsoft/Olive/).
+An example of how to create the model for ONNX Runtime Web with Olive can be found [here](https://github.com/microsoft/Olive/tree/main/examples/phi3).
diff --git a/js/chat/index.html b/js/chat/index.html
new file mode 100644
index 000000000..6886c7459
--- /dev/null
+++ b/js/chat/index.html
@@ -0,0 +1,43 @@
+<!DOCTYPE html>
+<html lang="en">
+
+<head>
+    <meta charset="utf-8" />
+    <meta name="viewport" content="width=device-width, initial-scale=1" />
+    <title>Chat with onnxruntime-web</title>
+    <!-- bootstrap provides the button, spinner and layout styling used by main.js -->
+    <link href="https://cdn.jsdelivr.net/npm/bootstrap@5/dist/css/bootstrap.min.css" rel="stylesheet" />
+    <link href="main.css" rel="stylesheet" />
+</head>
+
+<body data-bs-theme="dark">
+    <div class="container">
+        <h1>Chat with onnxruntime-web</h1>
+        <div id="status"></div>
+        <!-- chat history and model responses, filled in by main.js -->
+        <div id="scroll-wrapper">
+            <div id="chat-container">
+                <div id="chat-history"></div>
+            </div>
+        </div>
+        <!-- prompt input and send/stop button -->
+        <div id="input-area" class="input-group">
+            <textarea id="user-input" class="form-control" rows="2" placeholder="Type your prompt here"></textarea>
+            <button id="send-button" class="btn btn-primary">Send</button>
+        </div>
+    </div>
+    <script type="module" src="dist/main.min.js"></script>
+</body>
+
+</html>
\ No newline at end of file
diff --git a/js/chat/llm.js b/js/chat/llm.js
new file mode 100644
index 000000000..3bf51090b
--- /dev/null
+++ b/js/chat/llm.js
@@ -0,0 +1,226 @@
+import * as ort from 'onnxruntime-web/webgpu';
+
+ort.env.wasm.numThreads = 1;
+ort.env.wasm.simd = true;
+ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', '') + 'dist/';
+
+
+function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }
+
+//
+// load file from server or cache
+//
+async function fetchAndCache(url) {
+ try {
+ const cache = await caches.open("onnx");
+ let cachedResponse = await cache.match(url);
+ if (cachedResponse === undefined) {
+ log(`${url} (network)`);
+ const buffer = await fetch(url).then(response => response.arrayBuffer());
+ try {
+ await cache.put(url, new Response(buffer));
+ } catch (error) {
+ console.error(error);
+ }
+ return buffer;
+ }
+ log(`${url} (cached)`);
+ const data = await cachedResponse.arrayBuffer();
+ return data;
+ } catch (error) {
+ log(`can't fetch ${url}`);
+ throw error;
+ }
+}
+
+//
+// class to handle a large language model on top of onnxruntime-web
+//
+export class LLM {
+ sess = undefined;
+ profiler = false;
+ feed = {};
+ output_tokens = [];
+ eos = 2;
+ need_position_ids = true;
+ stop = false;
+ kv_dims = [];
+ dtype = "float16";
+ max_tokens = 9999;
+
+ constructor() {
+ }
+
+ async load(model, options) {
+ const provider = options.provider || "webgpu";
+ const verbose = options.verbose;
+ const local = options.local;
+ const hasFP16 = (provider === "wasm") ? false : options.hasFP16;
+ this.profiler = options.profiler;
+
+ const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main";
+ let model_file = model.file || "model";
+ model_file = (hasFP16) ? model_file + "_q4f16.onnx" : model_file + "_q4.onnx";
+
+ log(`loading... ${model.name}, ${provider}`);
+ const json_bytes = await fetchAndCache(model_path + "/config.json");
+ let textDecoder = new TextDecoder();
+ const model_config = JSON.parse(textDecoder.decode(json_bytes));
+
+ const model_bytes = await fetchAndCache(model_path + "/onnx/" + model_file);
+ const externaldata = (model.externaldata) ? await fetchAndCache(model_path + "/onnx/" + model_file + '_data') : false;
+ let modelSize = model_bytes.byteLength;
+ if (externaldata) {
+ modelSize += externaldata.byteLength;
+ }
+ log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`);
+
+ const opt = {
+ executionProviders: [provider],
+ preferredOutputLocation: {},
+ }
+
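+        // for webgpu, keep the present.* kv-cache outputs on the GPU so they can be fed
+        // back as past_key_values.* on the next step without copying them to the CPU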
+ switch (provider) {
+ case "webgpu":
+ for (let i = 0; i < model_config.num_hidden_layers; ++i) {
+ opt.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
+ opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
+ }
+ break;
+ }
+
+        if (externaldata) {
+ opt.externalData = [
+ {
+ data: externaldata,
+ path: model_file + "_data",
+ },
+ ]
+ }
+ if (verbose) {
+ opt.logSeverityLevel = 0;
+ opt.logVerbosityLevel = 0;
+ ort.env.logLevel = "verbose";
+ }
+
+ ort.env.webgpu.profiling = {}
+ if (this.profiler) {
+ opt.enableProfiling = true;
+ ort.env.webgpu.profilingMode = 'default';
+ ort.env.webgpu.profiling.mode = 'default';
+ }
+
+ this.sess = await ort.InferenceSession.create(model_bytes, opt);
+ this.eos = model_config.eos_token_id;
+ this.kv_dims = [1, model_config.num_key_value_heads, 0, model_config.hidden_size / model_config.num_attention_heads];
+ this.dtype = (hasFP16) ? "float16" : "float32";
+ this.num_layers = model_config.num_hidden_layers;
+        this.initialize_feed();
+ }
+
+    initialize_feed() {
+ const feed = this.feed;
+
+ // dispose of previous gpu buffers
+ for (const name in feed) {
+ const t = feed[name];
+ if (t.location === 'gpu-buffer') {
+ t.dispose();
+ }
+ }
+ this.feed = {};
+        // the key/value cache is zero copy, we just pass the gpu buffer by reference
+ const empty = (this.dtype === "float16") ? new Uint16Array() : [];
+ for (let i = 0; i < this.num_layers; ++i) {
+ this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims)
+ this.feed[`past_key_values.${i}.value`] = new ort.Tensor(this.dtype, empty, this.kv_dims)
+ }
+ this.output_tokens = [];
+ }
+
+ //
+    // poor man's argmax over the logits of the last position
+    //
+ argmax(t) {
+ const arr = t.data;
+ const start = t.dims[2] * (t.dims[1] - 1);
+ let max = arr[start];
+ let maxidx = 0;
+
+ for (let i = 0; i < t.dims[2]; i++) {
+ const val = arr[i + start];
+ if (!isFinite(val)) {
+                throw new Error("found non-finite value in logits");
+ }
+ if (val > max) {
+ max = arr[i + start];
+ maxidx = i;
+ }
+ }
+ return maxidx;
+ }
+
+ //
+ // update key value cache
+ //
+ update_kv_cache(feed, outputs) {
+ for (const name in outputs) {
+ if (name.startsWith('present')) {
+ let newName = name.replace('present', 'past_key_values');
+ // dispose previous gpu buffers
+ const t = feed[newName];
+ if (t.location === 'gpu-buffer') {
+ t.dispose();
+ }
+ feed[newName] = outputs[name];
+ }
+ }
+ }
+
+ //
+ // tell generate to stop()
+ //
+ abort() {
+ this.stop = true;
+ }
+
+ //
+ // prefill prompt and generate tokens, greedy search only
+ //
+ async generate(tokens, callback, options) {
+ const max_tokens = options.max_tokens || 256;
+ const feed = this.feed;
+ const input_ids = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]);
+ feed['input_ids'] = input_ids;
+ this.stop = false;
+
+ this.output_tokens.push(...input_ids.data);
+
+ let last_token = 0n;
+ let seqlen = this.output_tokens.length;
+ const input_len = input_ids.size;
+
+ if (this.need_position_ids) {
+ feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: input_len }, (_, i) => BigInt(seqlen - input_len + i)), [1, input_len]);
+ }
+
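+        // the first run processes the whole prompt (prefill); later iterations feed only the
+        // newly generated token and reuse the kv-cache. 32007 is Phi-3's <|end|> token.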
+ while (last_token != this.eos && last_token != 32007 && seqlen < max_tokens && !this.stop) {
+ seqlen = this.output_tokens.length;
+ feed['attention_mask'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, () => 1n), [1, seqlen]);
+ const outputs = await this.sess.run(feed);
+ last_token = BigInt(this.argmax(outputs.logits));
+ this.output_tokens.push(last_token);
+ if (callback && !this.profiler) {
+ callback(this.output_tokens);
+ }
+ this.update_kv_cache(feed, outputs);
+ feed['input_ids'] = new ort.Tensor('int64', BigInt64Array.from([last_token]), [1, 1]);
+ if (this.need_position_ids) {
+ feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]);
+ }
+ }
+ if (this.profiler) {
+ this.sess.endProfiling();
+ }
+ return this.output_tokens;
+ }
+}
diff --git a/js/chat/main.css b/js/chat/main.css
new file mode 100644
index 000000000..d8088468e
--- /dev/null
+++ b/js/chat/main.css
@@ -0,0 +1,57 @@
+body {
+ color: #f5f5f5;
+ font-family: 'Arial', sans-serif;
+}
+
+.user-message {
+ background-color: rgb(86, 144, 163);
+ color: white;
+ padding: 10px;
+ border-radius: 10px;
+ white-space: pre-wrap;
+ width: fit-content;
+}
+
+.response-message {
+ background-color: rgb(62, 62, 62);
+ color: white;
+ padding: 10px;
+ border-radius: 10px;
+ padding-right: 20px;
+ position: relative;
+ margin-right: auto;
+}
+
+.response-message p {
+ margin-right: 40px;
+}
+
+#chat-container {
+ display: none;
+ margin: 0 auto;
+ overflow: auto;
+}
+
+#chat-history {
+ display: flex;
+ flex-direction: column;
+}
+
+.copy-button {
+ position: absolute;
+ bottom: 5px;
+ right: 5px;
+ margin: 0 5px 5px 0;
+}
+
+#scroll-wrapper {
+ padding-bottom: 5.5rem;
+}
+
+#input-area {
+ position: fixed;
+ bottom: 0;
+ margin-bottom: 5px;
+ left: 50%;
+ transform: translateX(-50%);
+}
\ No newline at end of file
diff --git a/js/chat/main.js b/js/chat/main.js
new file mode 100644
index 000000000..fa92ddb80
--- /dev/null
+++ b/js/chat/main.js
@@ -0,0 +1,306 @@
+import { env, AutoTokenizer } from '@xenova/transformers';
+import { LLM } from './llm.js';
+import { marked } from 'marked';
+
+
+const MODELS = {
+ "phi3": { name: "phi3", path: "microsoft/Phi-3-mini-4k-instruct-onnx-web", externaldata: true },
+ "phi3dev": { name: "phi3dev", path: "schmuell/Phi-3-mini-4k-instruct-onnx-web", externaldata: true },
+}
+
+const preCannedQueries = {
+ "1": "Tell me about the lighthouse of Alexandria.",
+ "2": "Did the lighthouse of Alexandria existed at the same time the library of Alexandria existed?",
+ "3": "How did the Pharos lighthouse impact ancient maritime trade?",
+ "4": "Tell me about Constantinople.",
+};
+
+// inline SVG clipboard glyph shown on the copy-to-clipboard button
+const clipboardIcon = `<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" fill="none" stroke="currentColor" stroke-width="1.5" viewBox="0 0 16 16">
+    <rect x="3" y="2.5" width="10" height="12.5" rx="1.5"/>
+    <rect x="5.5" y="1" width="5" height="3" rx="1"/>
+    </svg>`
+
+function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }
+
+marked.use({ mangle: false, headerIds: false });
+
+const sendButton = document.getElementById('send-button');
+const scrollWrapper = document.getElementById('scroll-wrapper');
+
+//
+// auto scroll the content area until a user scrolls up
+//
+let isAutoScrollOn = true;
+let lastKnownScrollPosition = 0;
+let ticking = false;
+
+const autoScroller = new ResizeObserver(() => {
+ if (isAutoScrollOn) {
+ scrollWrapper.scrollIntoView({ behavior: "smooth", block: "end" });
+ }
+});
+
+document.addEventListener("scroll", () => {
+ if (!ticking && isAutoScrollOn && window.scrollY < lastKnownScrollPosition) {
+ window.requestAnimationFrame(() => {
+ isAutoScrollOn = false;
+ ticking = false;
+ });
+ ticking = true;
+ }
+ else if (!ticking && !isAutoScrollOn && window.scrollY > lastKnownScrollPosition &&
+ window.scrollY >= document.documentElement.scrollHeight - window.innerHeight - 30) {
+ window.requestAnimationFrame(() => {
+ isAutoScrollOn = true;
+ ticking = false;
+ });
+ ticking = true;
+ }
+ lastKnownScrollPosition = window.scrollY;
+});
+
+
+//
+// make response available for copying to clipboard
+//
+function copyTextToClipboard(responseDiv) {
+    const copyButton = document.createElement('button');
+    copyButton.className = 'btn btn-secondary copy-button';
+    copyButton.innerHTML = clipboardIcon;
+    copyButton.onclick = () => {
+        navigator.clipboard.writeText(responseDiv.innerText);
+    };
+    responseDiv.appendChild(copyButton);
+}
+
+//
+// user hits send, enter or ctl enter
+//
+async function submitRequest(e) {
+ if (sendButton.innerHTML == "Stop") {
+ llm.abort();
+ return;
+ }
+
+ // enter clears the chat history, ctl enter will continue the conversation
+ const continuation = e.ctrlKey && e.key === 'Enter';
+
+ document.getElementById('chat-container').style.display = 'block';
+
+ let input = document.getElementById('user-input').value;
+ if (input.length == 0) {
+ document.getElementById('chat-history').context = "";
+ let chatHistory = document.getElementById('chat-history');
+ while (chatHistory.firstChild) {
+ chatHistory.firstChild.remove();
+ }
+ return;
+ }
+ let context = document.getElementById('chat-history').context;
+ if (context === undefined) {
+ context = "";
+ }
+
+ // append to chat history
+ let chatHistory = document.getElementById('chat-history');
+ let userMessageDiv = document.createElement('div');
+ userMessageDiv.className = 'mb-2 user-message';
+ userMessageDiv.innerText = input;
+ chatHistory.appendChild(userMessageDiv);
+
+ // container for llm response
+ let responseDiv = document.createElement('div');
+ responseDiv.className = 'response-message mb-2 text-start';
+ responseDiv.style.minHeight = '3em';
+ let spinner = document.createElement('div');
+ spinner.className = 'spinner-border text-light';
+ spinner.setAttribute('role', 'status');
+ responseDiv.appendChild(spinner);
+ chatHistory.appendChild(responseDiv);
+
+ // toggle button to stop text generation
+ sendButton.innerHTML = "Stop";
+
+ // change autoScroller to keep track of our new responseDiv
+ autoScroller.observe(responseDiv);
+
+ if (continuation) {
+ input = context + " " + input;
+ }
+
+ Query(continuation, input, (word) => {
+ responseDiv.innerHTML = marked.parse(word);
+ }).then(() => {
+ chatHistory.context = responseDiv.innerHTML;
+        copyTextToClipboard(responseDiv);
+ sendButton.innerHTML = "Send";
+ spinner.remove();
+ }).catch(error => {
+ console.error(error);
+ sendButton.innerHTML = "Send";
+ spinner.remove();
+ });
+
+ // Clear user input
+ document.getElementById('user-input').value = '';
+}
+
+
+//
+// event listener for Ctrl+Enter or Enter
+//
+document.getElementById('user-input').addEventListener('keydown', function (e) {
+ if (e.ctrlKey) {
+ if (e.key === 'Enter') {
+ submitRequest(e);
+ } else {
+ const query = preCannedQueries[e.key];
+ if (query) {
+ document.getElementById('user-input').value = query;
+ submitRequest(e);
+ }
+ }
+ } else if (e.key === 'Enter') {
+ e.preventDefault();
+ submitRequest(e);
+ }
+});
+
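+//
+// parse settings from the url query string, e.g. ?model=phi3dev&local=1&max_tokens=512
+//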
+function getConfig() {
+ const query = window.location.search.substring(1);
+ var config = {
+ model: "phi3",
+ provider: "webgpu",
+ profiler: 0,
+ verbose: 0,
+ threads: 1,
+ show_special: 0,
+ csv: 0,
+ max_tokens: 9999,
+ local: 0,
+ }
+ let vars = query.split("&");
+ for (var i = 0; i < vars.length; i++) {
+ let pair = vars[i].split("=");
+ if (pair[0] in config) {
+ const key = pair[0];
+ const value = decodeURIComponent(pair[1]);
+ if (typeof config[key] == "number") {
+ config[key] = parseInt(value);
+ }
+ else {
+ config[key] = value;
+ }
+ } else if (pair[0].length > 0) {
+ throw new Error("unknown argument: " + pair[0]);
+ }
+ }
+ if (MODELS[config.model] !== undefined) {
+ config.model = MODELS[config.model];
+ }
+ return config;
+}
+
+const config = getConfig();
+
+// setup for transformers.js tokenizer
+env.localModelPath = 'models';
+env.allowRemoteModels = config.local == 0;
+env.allowLocalModels = config.local == 1;
+
+let tokenizer;
+
+const llm = new LLM();
+
+function token_to_text(tokenizer, tokens, startidx) {
+ const txt = tokenizer.decode(tokens.slice(startidx), { skip_special_tokens: config.show_special != 1, });
+ return txt;
+}
+
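+//
+// run one query: wrap it in the Phi-3 chat template, generate tokens and stream decoded text to the callback
+//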
+async function Query(continuation, query, cb) {
+ let prompt = (continuation) ? query : `<|system|>\nYou are a friendly assistant.<|end|>\n<|user|>\n${query}<|end|>\n<|assistant|>\n`;
+
+ const { input_ids } = await tokenizer(prompt, { return_tensor: false, padding: true, truncation: true });
+
+ // clear caches
+ // TODO: use kv_cache for continuation
+    llm.initialize_feed();
+
+ const start_timer = performance.now();
+ const output_index = llm.output_tokens.length + input_ids.length;
+ const output_tokens = await llm.generate(input_ids, (output_tokens) => {
+ if (output_tokens.length == input_ids.length + 1) {
+ // time to first token
+ const took = (performance.now() - start_timer) / 1000;
+ console.log(`time to first token in ${took.toFixed(1)}sec, ${input_ids.length} tokens`);
+ }
+ cb(token_to_text(tokenizer, output_tokens, output_index));
+ }, { max_tokens: config.max_tokens });
+
+ const took = (performance.now() - start_timer) / 1000;
+ cb(token_to_text(tokenizer, output_tokens, output_index));
+ const seqlen = output_tokens.length - output_index;
+ console.log(`${seqlen} tokens in ${took.toFixed(1)}sec, ${(seqlen / took).toFixed(2)} tokens/sec`);
+}
+
+//
+// Load the model and tokenizer
+//
+async function Init(hasFP16) {
+ try {
+ tokenizer = await AutoTokenizer.from_pretrained(config.model.path);
+
+ log("Loading model...");
+ await llm.load(config.model, {
+ provider: config.provider,
+ profiler: config.profiler,
+ verbose: config.verbose,
+ local: config.local,
+ max_tokens: config.max_tokens,
+ hasFP16: hasFP16,
+ });
+ log("Ready.");
+ } catch (error) {
+ log(error);
+ }
+}
+
+//
+// Check if we have webgpu and fp16
+//
+async function hasWebGPU() {
+ // returns 0 for webgpu with f16, 1 for webgpu without f16, 2 for no webgpu
+ if (!("gpu" in navigator)) {
+ return 2;
+ }
+ try {
+ const adapter = await navigator.gpu.requestAdapter()
+ if (adapter.features.has('shader-f16')) {
+ return 0;
+ }
+ return 1;
+ } catch (e) {
+ return 2;
+ }
+}
+
+window.onload = () => {
+ hasWebGPU().then((supported) => {
+ if (supported < 2) {
+ if (supported == 1) {
+ log("Your GPU or Browser does not support webgpu with fp16, using fp32 instead.");
+ }
+ Init(supported === 0).then(() => {
+ // adjustPadding();
+ sendButton.addEventListener('click', submitRequest);
+ const userInput = document.getElementById('user-input');
+ document.getElementById("status").style.display = "none";
+ userInput.focus();
+ });
+ } else {
+ log("Your GPU or Browser does not support webgpu");
+ }
+ });
+}
diff --git a/js/chat/package.json b/js/chat/package.json
new file mode 100644
index 000000000..098408a7b
--- /dev/null
+++ b/js/chat/package.json
@@ -0,0 +1,22 @@
+{
+ "name": "localchat",
+ "private": true,
+ "version": "0.0.0",
+ "type": "module",
+ "scripts": {
+ "dev": "webpack serve --no-client-overlay",
+ "build": "webpack",
+ "lint": "eslint . --ext js --report-unused-disable-directives"
+ },
+ "dependencies": {
+ "@xenova/transformers": "^2.17.1",
+ "copy-webpack-plugin": "^12.0.2",
+ "marked": "^12.0.2",
+ "onnxruntime-web": "1.19.0-dev.20240509-69cfcba38a",
+ "webpack": "^5.91.0"
+ },
+ "devDependencies": {
+ "webpack-cli": "^5.1.4",
+ "webpack-dev-server": "^5.0.4"
+ }
+}
diff --git a/js/chat/webpack.config.js b/js/chat/webpack.config.js
new file mode 100644
index 000000000..c9b78c380
--- /dev/null
+++ b/js/chat/webpack.config.js
@@ -0,0 +1,41 @@
+import CopyWebpackPlugin from 'copy-webpack-plugin';
+import { fileURLToPath } from 'url';
+import path from 'path';
+
+const __dirname = path.dirname(fileURLToPath(import.meta.url));
+
+export default {
+ mode: 'development',
+ devtool: 'source-map',
+ entry: {
+ 'dist/main': './main.js',
+ 'dist/main.min': './main.js',
+ },
+ output: {
+ filename: '[name].js',
+ path: __dirname,
+ library: {
+ type: 'module',
+ },
+ },
+ plugins: [
+        // copy the onnxruntime-web runtime files (*.jsep.wasm / *.jsep.mjs) to the dist folder
+ new CopyWebpackPlugin({
+ patterns: [
+ {
+ from: 'node_modules/onnxruntime-web/dist/*.jsep.*',
+ to: 'dist/[name][ext]'
+ },
+ ],
+ }),
+ ],
+ devServer: {
+ static: {
+ directory: __dirname
+ },
+ port: 8080
+ },
+ experiments: {
+ outputModule: true,
+ },
+};