diff --git a/package-lock.json b/package-lock.json index fd48d6bb5..24facb3e6 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,8 +9,13 @@ "version": "2.15.1", "license": "Apache-2.0", "dependencies": { +<<<<<<< HEAD + "@huggingface/jinja": "^0.1.0", + "onnxruntime-web": "1.17.0", +======= "@huggingface/jinja": "^0.2.1", "onnxruntime-web": "1.14.0", +>>>>>>> main "sharp": "^0.32.0" }, "devDependencies": { @@ -27,7 +32,7 @@ "webpack-dev-server": "^4.13.3" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.17.0" } }, "node_modules/@ampproject/remapping": { @@ -5322,9 +5327,9 @@ "dev": true }, "node_modules/long": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/long/-/long-4.0.0.tgz", - "integrity": "sha512-XsP+KhQif4bjX1kbuSiySJFNAehNxgLb6hPRGJ9QsUr8ajHkuXGdrHmFUTUUXhDwVX2R5bY4JNZEwbUiMhV+MA==" + "version": "5.2.3", + "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", + "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" }, "node_modules/lru-cache": { "version": "6.0.0", @@ -5748,23 +5753,15 @@ "url": "https://github.com/sponsors/sindresorhus" } }, - "node_modules/onnx-proto": { - "version": "4.0.4", - "resolved": "https://registry.npmjs.org/onnx-proto/-/onnx-proto-4.0.4.tgz", - "integrity": "sha512-aldMOB3HRoo6q/phyB6QRQxSt895HNNw82BNyZ2CMh4bjeKv7g/c+VpAFtJuEMVfYLMbRx61hbuqnKceLeDcDA==", - "dependencies": { - "protobufjs": "^6.8.8" - } - }, "node_modules/onnxruntime-common": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.14.0.tgz", - "integrity": "sha512-3LJpegM2iMNRX2wUmtYfeX/ytfOzNwAWKSq1HbRrKc9+uqG/FsEA0bbKZl1btQeZaXhC26l44NWpNUeXPII7Ew==" + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-common/-/onnxruntime-common-1.17.0.tgz", + "integrity": "sha512-Vq1remJbCPITjDMJ04DA7AklUTnbYUp4vbnm6iL7ukSt+7VErH0NGYfekRSTjxxurEtX7w41PFfnQlE6msjPJw==" }, "node_modules/onnxruntime-node": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.14.0.tgz", - "integrity": "sha512-5ba7TWomIV/9b6NH/1x/8QEeowsb+jBEvFzU6z0T4mNsFwdPqXeFUM7uxC6QeSRkEbWu3qEB0VMjrvzN/0S9+w==", + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-node/-/onnxruntime-node-1.17.0.tgz", + "integrity": "sha512-pRxdqSP3a6wtiFVkVX1V3/gsEMwBRUA9D2oYmcN3cjF+j+ILS+SIY2L7KxdWapsG6z64i5rUn8ijFZdIvbojBg==", "optional": true, "os": [ "win32", @@ -5772,20 +5769,20 @@ "linux" ], "dependencies": { - "onnxruntime-common": "~1.14.0" + "onnxruntime-common": "1.17.0" } }, "node_modules/onnxruntime-web": { - "version": "1.14.0", - "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.14.0.tgz", - "integrity": "sha512-Kcqf43UMfW8mCydVGcX9OMXI2VN17c0p6XvR7IPSZzBf/6lteBzXHvcEVWDPmCKuGombl997HgLqj91F11DzXw==", + "version": "1.17.0", + "resolved": "https://registry.npmjs.org/onnxruntime-web/-/onnxruntime-web-1.17.0.tgz", + "integrity": "sha512-O5IZrnJ4ABMmgttdcuG/y3z8WT0zMieCeh/4Eq3lf3CeLwKLoPno38WbAvDiRUkfKjXUyu2mw532YIuGi61YJA==", "dependencies": { "flatbuffers": "^1.12.0", "guid-typescript": "^1.0.9", - "long": "^4.0.0", - "onnx-proto": "^4.0.4", - "onnxruntime-common": "~1.14.0", - "platform": "^1.3.6" + "long": "^5.2.3", + "onnxruntime-common": "1.17.0", + "platform": "^1.3.6", + "protobufjs": "^7.2.4" } }, "node_modules/open": { @@ -6044,9 +6041,9 @@ } }, "node_modules/protobufjs": { - "version": "7.2.4", - "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.4.tgz", - "integrity": "sha512-AT+RJgD2sH8phPmCf7OUZR8xGdcJRga4+1cOaXJ64hvcSkVhNcRHOwIxUatPH15+nj59WAGTDv3LSGZPEQbJaQ==", + "version": "7.2.6", + "resolved": "https://registry.npmjs.org/protobufjs/-/protobufjs-7.2.6.tgz", + "integrity": "sha512-dgJaEDDL6x8ASUZ1YqWciTRrdOuYNzoOf27oHNfdyvKqHr5i0FV7FSLU+aIeFjyFgVxrpTOtQUi0BLLBymZaBw==", "hasInstallScript": true, "dependencies": { "@protobufjs/aspromise": "^1.1.2", @@ -6066,11 +6063,6 @@ "node": ">=12.0.0" } }, - "node_modules/protobufjs/node_modules/long": { - "version": "5.2.3", - "resolved": "https://registry.npmjs.org/long/-/long-5.2.3.tgz", - "integrity": "sha512-lcHwpNoggQTObv5apGNCTdJrO69eHOZMi4BNC+rTLER8iHAqGrUVeLh/irVIM7zTw2bOXA8T6uNPeujwOLg/2Q==" - }, "node_modules/proxy-addr": { "version": "2.0.7", "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", diff --git a/package.json b/package.json index a7a22e2ce..d23de9971 100644 --- a/package.json +++ b/package.json @@ -38,12 +38,12 @@ }, "homepage": "https://github.com/xenova/transformers.js#readme", "dependencies": { - "onnxruntime-web": "1.14.0", + "onnxruntime-web": "1.17.1", "sharp": "^0.32.0", "@huggingface/jinja": "^0.2.1" }, "optionalDependencies": { - "onnxruntime-node": "1.14.0" + "onnxruntime-node": "1.17.1" }, "devDependencies": { "@types/jest": "^29.5.1", diff --git a/src/models.js b/src/models.js index ef3d2806c..e2441a9bf 100644 --- a/src/models.js +++ b/src/models.js @@ -203,9 +203,16 @@ function validateInputs(session, inputs) { async function sessionRun(session, inputs) { const checkedInputs = validateInputs(session, inputs); try { - // @ts-ignore - let output = await session.run(checkedInputs); + // pass the original ort tensor + const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor])); + let output = await session.run(ortFeed); output = replaceTensors(output); + for (const [name, t] of Object.entries(checkedInputs)) { + // if we use gpu buffers for kv_caches, we own them and need to dispose() + if (name.startsWith('past_key_values')) { + t.dispose(); + }; + } return output; } catch (e) { // This usually occurs when the inputs are of the wrong type. diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 819c2dbb6..861d2d0d8 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -17,6 +17,7 @@ import { const DataTypeMap = Object.freeze({ float32: Float32Array, + float16: Uint16Array, float64: Float64Array, string: Array, // string[] int8: Int8Array, @@ -39,16 +40,32 @@ const ONNXTensor = ONNX.Tensor; export class Tensor { /** @type {number[]} Dimensions of the tensor. */ - dims; + get dims() { + // @ts-ignore + return this.ort_tensor.dims; + } + set dims(value) { + // FIXME: ONNXTensor declares dims as readonly so one needs to use the constructor() if dims change. + // @ts-ignore + this.ort_tensor.dims = value; + } /** @type {DataType} Type of the tensor. */ - type; + get type() { + return this.ort_tensor.type; + }; /** @type {DataArray} The data stored in the tensor. */ - data; + get data() { + return this.ort_tensor.data; + } /** @type {number} The number of elements in the tensor. */ - size; + get size() { + return this.ort_tensor.size; + }; + + ort_tensor; /** * Create a new Tensor or copy an existing Tensor. @@ -56,16 +73,15 @@ export class Tensor { */ constructor(...args) { if (args[0] instanceof ONNXTensor) { - // Create shallow copy - Object.assign(this, args[0]); - + this.ort_tensor = args[0]; } else { // Create new tensor - Object.assign(this, new ONNXTensor( + const t = new ONNXTensor( /** @type {DataType} */(args[0]), /** @type {Exclude} */(args[1]), args[2] - )); + ); + this.ort_tensor = t; } return new Proxy(this, { @@ -89,6 +105,11 @@ export class Tensor { }); } + dispose() { + this.ort_tensor.dispose(); + // this.ort_tensor = undefined; + } + /** * Returns an iterator object for iterating over the tensor data in row-major order. * If the tensor has more than one dimension, the iterator will yield subarrays.