3 changes: 3 additions & 0 deletions js/README.md
@@ -49,3 +49,6 @@ Click links for the README of each example.
* [Facebook Segment-Anything](segment-anything) - demonstrates how to run [segment-anything](https://github.com/facebookresearch/segment-anything) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.

* [Stable Diffusion Turbo](sd-turbo) - demonstrates how to run [Stable Diffusion Turbo](https://huggingface.co/stabilityai/sd-turbo) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.

* [Phi-3-mini-4k-instruct](chat) - demonstrates how to run [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [onnxruntime-web](https://github.com/microsoft/onnxruntime/js) with webgpu.

54 changes: 54 additions & 0 deletions js/chat/README.md
@@ -0,0 +1,54 @@
# Local Chat using Phi3, ONNX Runtime Web and WebGPU

This repository contains an example of running [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) in your browser using [ONNX Runtime Web](https://github.com/microsoft/onnxruntime) with WebGPU.

You can try out the live demo [here](https://guschmue.github.io/ort-webgpu/chat/index.html).

We keep this example simple and use the onnxruntime-web API directly, without a
higher-level framework like [transformers.js](https://github.com/xenova/transformers.js).
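
For reference, the core onnxruntime-web calls have roughly the following shape (a minimal sketch only; a real Phi3 feed also needs `attention_mask`, `position_ids` and the key/value cache, which the `LLM` class in llm.js handles):

```js
import * as ort from 'onnxruntime-web/webgpu';

// create an inference session on the WebGPU execution provider
const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: ['webgpu'],
});

// build an input feed and run the model once
const input_ids = new ort.Tensor('int64', BigInt64Array.from([1n, 2n, 3n]), [1, 3]);
const outputs = await session.run({ input_ids });
console.log(outputs.logits.dims);
```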

## Getting Started

### Prerequisites

Ensure that you have [Node.js](https://nodejs.org/) installed on your machine.

### Installation

Install the required dependencies:

```sh
npm install
```

### Building the project

Build the project:

```sh
npm run build
```

The output can be found in the ***dist*** directory.

### Building for development

```sh
npm run dev
```

This will build the project and start a dev server.
Point your browser to http://localhost:8080/.

### The Phi3 ONNX Model

The model used in this example is hosted on [Hugging Face](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-onnx-web). It is slightly different from the ONNX model for CUDA or CPU:
1. The model output 'logits' is kept as float32 (even for float16 models) because JavaScript does not support float16.
2. Our WebGPU implementation uses the custom MultiHeadAttention operator instead of Group Query Attention.
3. Phi3 is larger than 2GB, so we need to use external data files. To keep them cacheable in the browser,
both model.onnx and model.onnx.data are kept under 2GB (see the loading sketch below).
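
When loading the model, onnxruntime-web receives the graph and the external weights separately through the `externalData` session option, roughly like this (a minimal sketch; the file names are illustrative and the complete logic lives in llm.js):

```js
import * as ort from 'onnxruntime-web/webgpu';

// fetch graph and external weights as ArrayBuffers (llm.js also caches them via the Cache API)
const model_bytes = await (await fetch('model_q4f16.onnx')).arrayBuffer();
const weights = await (await fetch('model_q4f16.onnx_data')).arrayBuffer();

const session = await ort.InferenceSession.create(model_bytes, {
  executionProviders: ['webgpu'],
  // map the external data path referenced inside the graph to its bytes
  externalData: [{ data: weights, path: 'model_q4f16.onnx_data' }],
});
```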

The model was created using the [ONNX Runtime GenAI model builder](https://github.com/microsoft/onnxruntime-genai/tree/main/src/python/py/models).

If you would like to create the model yourself, you can use [Olive](https://github.com/microsoft/Olive/).
An example of how to create the model for ONNX Runtime Web with Olive can be found [here](https://github.com/microsoft/Olive/tree/main/examples/phi3).
43 changes: 43 additions & 0 deletions js/chat/index.html
@@ -0,0 +1,43 @@
<!doctype html>
<html lang="en">

<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.1/dist/css/bootstrap.min.css" rel="stylesheet"
integrity="sha384-4bw+/aepP/YC94hEpVNVgiZdgIC5+VKNBQNGCHeKRQN+PtmoHDEXuppvnDJzQIu9" crossorigin="anonymous" />
<link rel="stylesheet" href="main.css">

<title>Chat with onnxruntime-web</title>
</head>

<body data-bs-theme="dark">
<div id="root"></div>

<div class="container">
<div class="row pt-3">
<div class="col-md-8 col-12">
<h2>Chat with onnxruntime-web</h2>
</div>
<div id="status">
</div>
</div>
<div id="scroll-wrapper">
<div id="chat-container" class="card">
<div class="card-body">
<div id="chat-history"></div>
</div>
</div>
</div>
</div>
<div class="container p-0 card" id="input-area">
<div class="input-group">
<textarea class="form-control" id="user-input" placeholder="Type your question here ..."></textarea>
<button id="send-button" class="btn btn-primary">Send</button>
</div>
</div>

<script type="module" src="dist/main.js"></script>
</body>

</html>
226 changes: 226 additions & 0 deletions js/chat/llm.js
@@ -0,0 +1,226 @@
import * as ort from 'onnxruntime-web/webgpu';

ort.env.wasm.numThreads = 1;
ort.env.wasm.simd = true;
ort.env.wasm.wasmPaths = document.location.pathname.replace('index.html', '') + 'dist/';


function log(i) { console.log(i); document.getElementById('status').innerText += `\n${i}`; }

//
// load file from server or cache
//
async function fetchAndCache(url) {
try {
const cache = await caches.open("onnx");
let cachedResponse = await cache.match(url);
if (cachedResponse === undefined) {
log(`${url} (network)`);
const buffer = await fetch(url).then(response => response.arrayBuffer());
try {
await cache.put(url, new Response(buffer));
} catch (error) {
console.error(error);
}
return buffer;
}
log(`${url} (cached)`);
const data = await cachedResponse.arrayBuffer();
return data;
} catch (error) {
log(`can't fetch ${url}`);
throw error;
}
}

//
// class to handle a large language model on top of onnxruntime-web
//
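// Typical usage (a sketch only; the model descriptor and the tokenized prompt come from the caller,
// e.g. main.js, and the values shown here are illustrative):
//
//   const llm = new LLM();
//   await llm.load({ name: "phi3", path: "microsoft/Phi-3-mini-4k-instruct-onnx-web", externaldata: true },
//                  { provider: "webgpu", hasFP16: true });
//   const tokens = await llm.generate(prompt_token_ids, (t) => { /* stream partial output */ }, { max_tokens: 256 });
//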
export class LLM {
sess = undefined;
profiler = false;
feed = {};
output_tokens = [];
eos = 2;
need_position_ids = true;
stop = false;
kv_dims = [];
dtype = "float16";
max_tokens = 9999;

constructor() {
}

async load(model, options) {
const provider = options.provider || "webgpu";
const verbose = options.verbose;
const local = options.local;
const hasFP16 = (provider === "wasm") ? false : options.hasFP16;
this.profiler = options.profiler;

const model_path = (local) ? "models/" + model.path : "https://huggingface.co/" + model.path + "/resolve/main";
let model_file = model.file || "model";
model_file = (hasFP16) ? model_file + "_q4f16.onnx" : model_file + "_q4.onnx";

log(`loading... ${model.name}, ${provider}`);
const json_bytes = await fetchAndCache(model_path + "/config.json");
let textDecoder = new TextDecoder();
const model_config = JSON.parse(textDecoder.decode(json_bytes));

const model_bytes = await fetchAndCache(model_path + "/onnx/" + model_file);
const externaldata = (model.externaldata) ? await fetchAndCache(model_path + "/onnx/" + model_file + '_data') : false;
let modelSize = model_bytes.byteLength;
if (externaldata) {
modelSize += externaldata.byteLength;
}
log(`model size ${Math.round(modelSize / 1024 / 1024)} MB`);

const opt = {
executionProviders: [provider],
preferredOutputLocation: {},
}

switch (provider) {
case "webgpu":
for (let i = 0; i < model_config.num_hidden_layers; ++i) {
opt.preferredOutputLocation[`present.${i}.key`] = 'gpu-buffer';
opt.preferredOutputLocation[`present.${i}.value`] = 'gpu-buffer';
}
break;
}

if (externaldata) {
opt.externalData = [
{
data: externaldata,
path: model_file + "_data",
},
]
}
if (verbose) {
opt.logSeverityLevel = 0;
opt.logVerbosityLevel = 0;
ort.env.logLevel = "verbose";
}

ort.env.webgpu.profiling = {}
if (this.profiler) {
opt.enableProfiling = true;
ort.env.webgpu.profilingMode = 'default';
ort.env.webgpu.profiling.mode = 'default';
}

this.sess = await ort.InferenceSession.create(model_bytes, opt);
this.eos = model_config.eos_token_id;
this.kv_dims = [1, model_config.num_key_value_heads, 0, model_config.hidden_size / model_config.num_attention_heads];
this.dtype = (hasFP16) ? "float16" : "float32";
this.num_layers = model_config.num_hidden_layers;
this.initilize_feed();
}

initilize_feed() {
const feed = this.feed;

// dispose of previous gpu buffers
for (const name in feed) {
const t = feed[name];
if (t.location === 'gpu-buffer') {
t.dispose();
}
}
this.feed = {};
// key value cache is zero copy, just pass gpu buffer as reference
const empty = (this.dtype === "float16") ? new Uint16Array() : [];
for (let i = 0; i < this.num_layers; ++i) {
this.feed[`past_key_values.${i}.key`] = new ort.Tensor(this.dtype, empty, this.kv_dims)
this.feed[`past_key_values.${i}.value`] = new ort.Tensor(this.dtype, empty, this.kv_dims)
}
this.output_tokens = [];
}

//
// poor man's argmax over the last token's logits (greedy decoding)
//
argmax(t) {
const arr = t.data;
const start = t.dims[2] * (t.dims[1] - 1);
let max = arr[start];
let maxidx = 0;

for (let i = 0; i < t.dims[2]; i++) {
const val = arr[i + start];
if (!isFinite(val)) {
throw new Error("found infinitive in logits");
}
if (val > max) {
max = arr[i + start];
maxidx = i;
}
}
return maxidx;
}

//
// update key value cache
//
update_kv_cache(feed, outputs) {
for (const name in outputs) {
if (name.startsWith('present')) {
let newName = name.replace('present', 'past_key_values');
// dispose previous gpu buffers
const t = feed[newName];
if (t.location === 'gpu-buffer') {
t.dispose();
}
feed[newName] = outputs[name];
}
}
}

//
// tell generate to stop()
//
abort() {
this.stop = true;
}

//
// prefill prompt and generate tokens, greedy search only
//
async generate(tokens, callback, options) {
const max_tokens = options.max_tokens || 256;
const feed = this.feed;
const input_ids = new ort.Tensor('int64', BigInt64Array.from(tokens.map(BigInt)), [1, tokens.length]);
feed['input_ids'] = input_ids;
this.stop = false;

this.output_tokens.push(...input_ids.data);

let last_token = 0n;
let seqlen = this.output_tokens.length;
const input_len = input_ids.size;

if (this.need_position_ids) {
feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from({ length: input_len }, (_, i) => BigInt(seqlen - input_len + i)), [1, input_len]);
}

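// greedy decoding loop: stop on the eos token, on 32007 (Phi-3's <|end|> end-of-turn token), after max_tokens, or when abort() is called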
while (last_token != this.eos && last_token != 32007 && seqlen < max_tokens && !this.stop) {
seqlen = this.output_tokens.length;
feed['attention_mask'] = new ort.Tensor('int64', BigInt64Array.from({ length: seqlen }, () => 1n), [1, seqlen]);
const outputs = await this.sess.run(feed);
last_token = BigInt(this.argmax(outputs.logits));
this.output_tokens.push(last_token);
if (callback && !this.profiler) {
callback(this.output_tokens);
}
this.update_kv_cache(feed, outputs);
feed['input_ids'] = new ort.Tensor('int64', BigInt64Array.from([last_token]), [1, 1]);
if (this.need_position_ids) {
feed['position_ids'] = new ort.Tensor('int64', BigInt64Array.from([BigInt(seqlen)]), [1, 1]);
}
}
if (this.profiler) {
this.sess.endProfiling();
}
return this.output_tokens;
}
}
57 changes: 57 additions & 0 deletions js/chat/main.css
@@ -0,0 +1,57 @@
body {
color: #f5f5f5;
font-family: 'Arial', sans-serif;
}

.user-message {
background-color: rgb(86, 144, 163);
color: white;
padding: 10px;
border-radius: 10px;
white-space: pre-wrap;
width: fit-content;
}

.response-message {
background-color: rgb(62, 62, 62);
color: white;
padding: 10px;
border-radius: 10px;
padding-right: 20px;
position: relative;
margin-right: auto;
}

.response-message p {
margin-right: 40px;
}

#chat-container {
display: none;
margin: 0 auto;
overflow: auto;
}

#chat-history {
display: flex;
flex-direction: column;
}

.copy-button {
position: absolute;
bottom: 5px;
right: 5px;
margin: 0 5px 5px 0;
}

#scroll-wrapper {
padding-bottom: 5.5rem;
}

#input-area {
position: fixed;
bottom: 0;
margin-bottom: 5px;
left: 50%;
transform: translateX(-50%);
}