665 changes: 352 additions & 313 deletions package-lock.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions package.json
@@ -27,7 +27,9 @@
"license": "Apache-2.0",
"homepage": "https://github.com/mlc-ai/web-llm",
"devDependencies": {
"@mlc-ai/web-runtime": "^0.23.0-dev0",
"@mlc-ai/web-tokenizers": "^0.1.6",
"@mlc-ai/web-xgrammar": "0.1.0",
"@next/eslint-plugin-next": "^14.2.3",
"@rollup/plugin-commonjs": "^20.0.0",
"@rollup/plugin-node-resolve": "^13.0.4",
@@ -50,8 +52,6 @@
"rollup-plugin-typescript2": "^0.34.1",
"ts-jest": "^29.1.2",
"tslib": "^2.3.1",
"@mlc-ai/web-runtime": "0.18.0-dev2",
"@mlc-ai/web-xgrammar": "0.1.0",
"typescript": "^4.9.5"
},
"dependencies": {
6 changes: 3 additions & 3 deletions src/cache_util.ts
@@ -29,7 +29,7 @@ export async function hasModelInCache(
const modelRecord = findModelRecord(modelId, appConfig);
const modelUrl = cleanModelUrl(modelRecord.model);
const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache";
return tvmjs.hasNDArrayInCache(modelUrl, "webllm/model", cacheType);
return tvmjs.hasTensorInCache(modelUrl, "webllm/model", cacheType);
}

export async function deleteModelAllInfoInCache(
@@ -60,10 +60,10 @@ export async function deleteModelInCache(
const modelUrl = cleanModelUrl(modelRecord.model);
let modelCache: tvmjs.ArtifactCacheTemplate;
if (appConfig.useIndexedDBCache) {
tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "indexeddb");
tvmjs.deleteTensorCache(modelUrl, "webllm/model", "indexeddb");
modelCache = new tvmjs.ArtifactIndexedDBCache("webllm/model");
} else {
tvmjs.deleteNDArrayCache(modelUrl, "webllm/model", "cache");
tvmjs.deleteTensorCache(modelUrl, "webllm/model", "cache");
modelCache = new tvmjs.ArtifactCache("webllm/model");
}
await modelCache.deleteInCache(new URL("tokenizer.model", modelUrl).href);
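For orientation, a minimal usage sketch of the cache helpers touched above (hedged: the `@mlc-ai/web-llm` entry-point exports, the `prebuiltAppConfig` default, and the model id are assumptions, not part of this diff):

```ts
// Hypothetical caller of the renamed cache path: check for cached weights,
// then remove them. Storage goes to the Cache API or IndexedDB depending on
// appConfig.useIndexedDBCache, under the "webllm/model" scope.
import {
  hasModelInCache,
  deleteModelInCache,
  prebuiltAppConfig,
} from "@mlc-ai/web-llm";

async function evictModel(modelId: string): Promise<void> {
  if (await hasModelInCache(modelId, prebuiltAppConfig)) {
    await deleteModelInCache(modelId, prebuiltAppConfig);
  }
}
```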
4 changes: 2 additions & 2 deletions src/embedding.ts
@@ -204,7 +204,7 @@ export class EmbeddingPipeline {
maskNDArray = maskNDArray.view([curBatchSize, maxInputSize]);

// 3.5 Actual forwarding on GPU, logits of shape (curBatchSize, maxInputSize, hidden_size)
const logitsCurBatchOnGPU: tvmjs.NDArray = this.prefill(
const logitsCurBatchOnGPU: tvmjs.Tensor = this.prefill(
inputNDArray,
maskNDArray,
this.params,
@@ -213,7 +213,7 @@

// 3.6 Copy logits to CPU, flatten to curBatchSize * maxInputSize * hidden_size
const hidden_size = logitsCurBatchOnGPU.shape[2];
let logitsCurBatchOnCPU: tvmjs.NDArray = this.tvm.empty(
let logitsCurBatchOnCPU: tvmjs.Tensor = this.tvm.empty(
logitsCurBatchOnGPU.shape,
logitsCurBatchOnGPU.dtype,
this.tvm.cpu(),
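The type renames above sit inside web-llm's usual GPU-to-CPU readback pattern. A hedged sketch of that pattern with the new `tvmjs.Tensor` type (the copy and sync steps are reconstructed from the matching code in src/llm_chat.ts; function and variable names are illustrative):

```ts
import * as tvmjs from "@mlc-ai/web-runtime";

// Allocate a CPU tensor with the same shape/dtype as the GPU result, copy the
// data over, and synchronize the device before reading it on the CPU side.
async function readBack(
  tvm: tvmjs.Instance,
  resultOnGPU: tvmjs.Tensor,
): Promise<tvmjs.Tensor> {
  const resultOnCPU: tvmjs.Tensor = tvm.empty(
    resultOnGPU.shape,
    resultOnGPU.dtype,
    tvm.cpu(),
  );
  resultOnCPU.copyFrom(resultOnGPU);
  await tvm.webgpu().sync();
  return resultOnCPU;
}
```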
2 changes: 1 addition & 1 deletion src/engine.ts
@@ -367,7 +367,7 @@ export class MLCEngine implements MLCEngineInterface {
this.logger,
);
const cacheType = this.appConfig.useIndexedDBCache ? "indexeddb" : "cache";
await tvm.fetchNDArrayCache(
await tvm.fetchTensorCache(
modelUrl,
tvm.webgpu(),
"webllm/model",
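For reference, the renamed fetch call in context, as a hedged sketch (the diff truncates the argument list after "webllm/model", so the trailing cache-type argument is an assumption):

```ts
import * as tvmjs from "@mlc-ai/web-runtime";

// Fetch (and cache) the model weights for a WebGPU device. The scope string
// "webllm/model" matches the one used by the cache helpers in src/cache_util.ts.
async function fetchWeights(
  tvm: tvmjs.Instance,
  modelUrl: string,
  useIndexedDBCache: boolean,
): Promise<void> {
  const cacheType = useIndexedDBCache ? "indexeddb" : "cache";
  await tvm.fetchTensorCache(modelUrl, tvm.webgpu(), "webllm/model", cacheType);
}
```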
72 changes: 36 additions & 36 deletions src/llm_chat.ts
@@ -65,7 +65,7 @@ export class LLMChatPipeline {
// parameter states
private params: tvmjs.TVMObject;
private kvCache: tvmjs.TVMObject;
private logitsOnCPU?: tvmjs.NDArray = undefined;
private logitsOnCPU?: tvmjs.Tensor = undefined;
private filledKVCacheLength = 0;

// meta data
@@ -224,7 +224,7 @@
// 2. Get json stored in the vm's metadata function
const fgetMetadata = this.vm.getFunction("_metadata");
const ret_value = fgetMetadata();
const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
const metadataStr = ret_value.toString();
const metadata = JSON.parse(metadataStr);

// 3. Load parameters by name
@@ -671,7 +671,7 @@

// 2. Prefill each chunk
this.tvm.beginScope();
let logits: tvmjs.NDArray;
let logits: tvmjs.Tensor;
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
const chunkLen = chunkLens[i];
@@ -860,7 +860,7 @@
* @note precondition: inputTokens.length <= prefillChunkSize, since we take care of
* chunking in `getChunkedPrefillInputData()`.
*/
private getTokensEmbeddings(inputTokens: number[]): tvmjs.NDArray {
private getTokensEmbeddings(inputTokens: number[]): tvmjs.Tensor {
this.tvm.beginScope();
if (inputTokens.length > this.prefillChunkSize) {
throw new Error(
@@ -873,7 +873,7 @@
this.device,
);
inputData.copyFrom(inputTokens);
const embed: tvmjs.NDArray = this.tvm.detachFromCurrentScope(
const embed: tvmjs.Tensor = this.tvm.detachFromCurrentScope(
this.embed!(inputData, this.params),
);
this.tvm.endScope();
@@ -886,9 +886,9 @@
*/
private async getImageEmbeddings(
inputImage: ImageURL,
): Promise<tvmjs.NDArray> {
): Promise<tvmjs.Tensor> {
this.tvm.beginScope();
// 1. Transform ImageURL into image input in NDArray
// 1. Transform ImageURL into image input in TVMArray
const url = inputImage.url;
// url starting with `data:image` and `http` share the same loading method
const imgData: ImageData = await getImageDataFromURL(url);
@@ -900,7 +900,7 @@
.view([1, imgData.height, imgData.width, 3]); // NHWC

// 2. Call image embed kernel
const embed: tvmjs.NDArray = this.tvm.detachFromCurrentScope(
const embed: tvmjs.Tensor = this.tvm.detachFromCurrentScope(
this.image_embed!(pixelArray, this.params),
);
if (embed.shape[0] !== IMAGE_EMBED_SIZE) {
@@ -920,14 +920,14 @@
*
* @param inputData data to embed and forward
* @param inputDataLen length of this inputData, should smaller than prefill chunk size.
* @returns The logits returned by this forward as tvmjs.NDArray on GPU.
* @returns The logits returned by this forward as tvmjs.Tensor on GPU.
*
* @note Precondition: inputData's data length is smaller than prefill chunk size
*/
private async embedAndForward(
inputData: Array<Array<number> | ImageURL>,
inputDataLen: number,
): Promise<tvmjs.NDArray> {
): Promise<tvmjs.Tensor> {
if (inputDataLen > this.prefillChunkSize) {
throw new Error(
"InternalError: expect inputDataLen <= this.prefillChunkSize.",
@@ -938,18 +938,18 @@

// 1. Embed all inputData
this.tvm.beginScope();
const embeddings: tvmjs.NDArray[] = [];
const embeddings: tvmjs.Tensor[] = [];
for (let i = 0; i < inputData.length; i++) {
const data = inputData[i];
if (Array.isArray(data)) {
embeddings.push(this.getTokensEmbeddings(data));
embeddings.push(await this.getTokensEmbeddings(data));
} else {
embeddings.push(await this.getImageEmbeddings(data));
}
}

// 2. Concatenate embeddings
let allEmbeddings: tvmjs.NDArray;
let allEmbeddings: tvmjs.Tensor;
if (embeddings.length === 1) {
allEmbeddings = embeddings[0];
} else {
@@ -983,7 +983,7 @@
}

// NOTE: caller must call device.sync()
private updateLogitsOnCPU(logits: tvmjs.NDArray): tvmjs.NDArray {
private updateLogitsOnCPU(logits: tvmjs.Tensor): tvmjs.Tensor {
if (this.logitsOnCPU == undefined) {
this.logitsOnCPU = this.tvm.detachFromCurrentScope(
this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu()),
@@ -998,7 +998,7 @@
}

private async sampleTokenFromLogits(
logitsOnGPU: tvmjs.NDArray,
logitsOnGPU: tvmjs.Tensor,
genConfig?: GenerationConfig,
) {
// 0. Get value of temperature, top_p, and various penalties, possibly overridden by genConfig
@@ -1160,7 +1160,7 @@
const logitBiasBegin = performance.now();

const numTokens = Object.keys(logit_bias ?? {}).length;
const pos2seq_id = new Int32Array(numTokens).fill(0);
const pos2seqIds = new Int32Array(numTokens).fill(0);
const tokenIds = new Int32Array(numTokens);
const tokenLogitBias = new Float32Array(numTokens);

Expand All @@ -1173,23 +1173,23 @@ export class LLMChatPipeline {

this.tvm.beginScope();

const pos2seqIdsArray = this.tvm
const pos2seqIdsDevice = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(pos2seq_id);
.copyFrom(pos2seqIds);

const tokenIdsArray = this.tvm
const tokenIdsDevice = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(tokenIds);

const tokenLogitBiasArray = this.tvm
const tokenLogitBiasDevice = this.tvm
.empty([numTokens], "float32", this.device)
.copyFrom(tokenLogitBias);

this.fapplyLogitBias(
logitsOnGPU.view([1, this.fullVocabSize]),
pos2seqIdsArray,
tokenIdsArray,
tokenLogitBiasArray,
pos2seqIdsDevice,
tokenIdsDevice,
tokenLogitBiasDevice,
);

this.tvm.endScope();
@@ -1215,7 +1215,7 @@
if (numTokens > 0) {
const penaltyBegin = performance.now();

const pos2seq_id = new Int32Array(numTokens).fill(0);
const pos2seqIds = new Int32Array(numTokens).fill(0);
const tokenIds = new Int32Array(numTokens).fill(0);
const tokenCnt = new Int32Array(numTokens).fill(0);
const penalties = new Float32Array([
@@ -1232,29 +1232,29 @@
.empty([1], "int32", this.device)
.copyFrom([0]);

const pos2seqIdsArray = this.tvm
const pos2seqIdsDevice = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(pos2seq_id);
.copyFrom(pos2seqIds);

const tokenIdsArray = this.tvm
const tokenIdsDevice = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(tokenIds);

const tokenCntArray = this.tvm
const tokenCntDevice = this.tvm
.empty([numTokens], "int32", this.device)
.copyFrom(tokenCnt);

const penaltiesArray = this.tvm
const penaltiesDevice = this.tvm
.empty([1, 3], "float32", this.device)
.copyFrom(penalties);

this.fapplyPenalty(
logitsOnGPU.view([1, this.fullVocabSize]),
seqIdsArray,
pos2seqIdsArray,
tokenIdsArray,
tokenCntArray,
penaltiesArray,
pos2seqIdsDevice,
tokenIdsDevice,
tokenCntDevice,
penaltiesDevice,
);

this.tvm.endScope();
@@ -1280,13 +1280,13 @@
const temperatures = new Float32Array([temperature]);

this.tvm.beginScope();
const temperaturesArray = this.tvm
const temperaturesDevice = this.tvm
.empty([numSeqs], "float32", this.device)
.copyFrom(temperatures);

const probs = this.fsoftmaxWithTemperature(
logitsOnGPU.view([numSeqs, 1, this.fullVocabSize]),
temperaturesArray,
temperaturesDevice,
);
this.updateLogitsOnCPU(probs);
this.tvm.endScope();
@@ -1458,7 +1458,7 @@
const chunkLens: Array<number> = retGetChunks[1];

// 2. Prefill each chunk
let logitsOnGPU: tvmjs.NDArray;
let logitsOnGPU: tvmjs.Tensor;
for (let i = 0; i < chunks.length; i++) {
const chunk = chunks[i];
const chunkLen = chunkLens[i];
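The renamed `*Device` locals above all follow one staging pattern: build typed arrays on the CPU, upload them into device tensors inside a TVM scope, hand them to a packed function, and close the scope. A hedged sketch of that pattern for the logit-bias case (the construction of `tokenIds`/`tokenLogitBias` is elided in the diff, so its reconstruction here, and all free-standing parameter names, are assumptions):

```ts
import * as tvmjs from "@mlc-ai/web-runtime";

function applyLogitBias(
  tvm: tvmjs.Instance,
  device: tvmjs.DLDevice,
  fapplyLogitBias: tvmjs.PackedFunc,
  logitsOnGPU: tvmjs.Tensor,
  fullVocabSize: number,
  logitBias: Record<number, number>,
): void {
  const entries = Object.entries(logitBias);
  const numTokens = entries.length;
  // Single-sequence case: every biased position belongs to sequence id 0.
  const pos2seqIds = new Int32Array(numTokens).fill(0);
  const tokenIds = new Int32Array(entries.map(([id]) => Number(id)));
  const tokenLogitBias = new Float32Array(entries.map(([, bias]) => bias));

  tvm.beginScope();
  // Upload the CPU-side arrays into device tensors.
  const pos2seqIdsDevice = tvm
    .empty([numTokens], "int32", device)
    .copyFrom(pos2seqIds);
  const tokenIdsDevice = tvm
    .empty([numTokens], "int32", device)
    .copyFrom(tokenIds);
  const tokenLogitBiasDevice = tvm
    .empty([numTokens], "float32", device)
    .copyFrom(tokenLogitBias);

  // Apply the bias on the GPU logits, viewed as (1, vocab_size).
  fapplyLogitBias(
    logitsOnGPU.view([1, fullVocabSize]),
    pos2seqIdsDevice,
    tokenIdsDevice,
    tokenLogitBiasDevice,
  );
  tvm.endScope();
}
```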
2 changes: 1 addition & 1 deletion tsconfig.json
@@ -8,7 +8,7 @@
"strict": true,
"moduleResolution": "Node",
"esModuleInterop": true,
"lib": ["dom", "WebWorker"]
"lib": ["dom", "WebWorker", "es2022"]
},
"typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types"],
"include": ["src"],
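Adding `es2022` to `lib` exposes ES2022 standard-library typings to the compiler. The concrete motivation is not stated in this diff, so the snippet below is only an illustration of what the setting unlocks:

```ts
// These only type-check with "es2022" (or later) in compilerOptions.lib.
const lastChunk = [1, 2, 3].at(-1); // Array.prototype.at
const ownsKey = Object.hasOwn({ a: 1 }, "a"); // Object.hasOwn
const wrapped = new Error("load failed", { cause: new Error("network") }); // Error cause
```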