From 8dab860c1335344965ad706709020044472c0772 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 14 Jun 2023 17:40:14 +0200 Subject: [PATCH 01/22] Override `LOAD_FUNCTION` for decoder-only models --- src/models.js | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/models.js b/src/models.js index 05410e5d3..10dd92fad 100644 --- a/src/models.js +++ b/src/models.js @@ -578,6 +578,8 @@ function textgenUpdatebeam(beam, newTokenId) { * @extends Callable */ export class PreTrainedModel extends Callable { + static LOAD_FUNCTION = loadAutoModel; + /** * Creates a new instance of the `PreTrainedModel` class. * @param {Object} config The model configuration. @@ -624,7 +626,7 @@ export class PreTrainedModel extends Callable { local_files_only = false, revision = 'main', } = {}) { - let info = await loadAutoModel(pretrained_model_name_or_path, { + let info = await this.LOAD_FUNCTION(pretrained_model_name_or_path, { quantized, progress_callback, config, @@ -2034,6 +2036,8 @@ export class GPT2Model extends GPT2PreTrainedModel { * @extends GPT2PreTrainedModel */ export class GPT2LMHeadModel extends GPT2PreTrainedModel { + static LOAD_FUNCTION = decoderLoadModel; + /** * Creates a new instance of the `GPT2LMHeadModel` class. * @param {Object} config The configuration of the model. @@ -2109,6 +2113,8 @@ export class GPTNeoModel extends GPTNeoPreTrainedModel { } export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { + static LOAD_FUNCTION = decoderLoadModel; + /** * Creates a new instance of the `GPTNeoForCausalLM` class. * @param {Object} config The configuration of the model. @@ -2194,6 +2200,8 @@ export class CodeGenModel extends CodeGenPreTrainedModel { * @extends CodeGenPreTrainedModel */ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { + static LOAD_FUNCTION = decoderLoadModel; + /** * Creates a new instance of the `CodeGenForCausalLM` class. * @param {Object} config The model configuration object. From b0c288a904cb7aa5fdd47bf9405114042cdfa86f Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 14 Jun 2023 19:17:32 +0200 Subject: [PATCH 02/22] Use object destructuring in `_call` functions --- src/models.js | 100 +++++++++++++++++++++++++------------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/src/models.js b/src/models.js index 10dd92fad..48d9254af 100644 --- a/src/models.js +++ b/src/models.js @@ -1066,8 +1066,8 @@ export class BertForMaskedLM extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for masked language modeling. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + let { logits } = await super._call(model_inputs); + return new MaskedLMOutput(logits); } } @@ -1083,8 +1083,8 @@ export class BertForSequenceClassification extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } @@ -1100,8 +1100,8 @@ export class BertForTokenClassification extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for token classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new TokenClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new TokenClassifierOutput(logits); } } @@ -1117,8 +1117,8 @@ export class BertForQuestionAnswering extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for question answering. */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + let { start_logits, end_logits } = await super._call(model_inputs); + return new QuestionAnsweringModelOutput(start_logits, end_logits); } } ////////////////////////////////////////////////// @@ -1140,8 +1140,8 @@ export class DistilBertForSequenceClassification extends DistilBertPreTrainedMod * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } @@ -1157,8 +1157,8 @@ export class DistilBertForTokenClassification extends DistilBertPreTrainedModel * @returns {Promise} An object containing the model's output logits for token classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new TokenClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new TokenClassifierOutput(logits); } } @@ -1175,8 +1175,8 @@ export class DistilBertForQuestionAnswering extends DistilBertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for question answering. */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + let { start_logits, end_logits } = await super._call(model_inputs); + return new QuestionAnsweringModelOutput(start_logits, end_logits); } } @@ -1192,8 +1192,8 @@ export class DistilBertForMaskedLM extends DistilBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + let { logits } = await super._call(model_inputs); + return new MaskedLMOutput(logits); } } ////////////////////////////////////////////////// @@ -1216,8 +1216,8 @@ export class MobileBertForMaskedLM extends MobileBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + let { logits } = await super._call(model_inputs); + return new MaskedLMOutput(logits); } } @@ -1232,8 +1232,8 @@ export class MobileBertForSequenceClassification extends MobileBertPreTrainedMod * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } @@ -1248,8 +1248,8 @@ export class MobileBertForQuestionAnswering extends MobileBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + let { start_logits, end_logits } = await super._call(model_inputs); + return new QuestionAnsweringModelOutput(start_logits, end_logits); } } ////////////////////////////////////////////////// @@ -1267,8 +1267,8 @@ export class SqueezeBertForMaskedLM extends SqueezeBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + let { logits } = await super._call(model_inputs); + return new MaskedLMOutput(logits); } } export class SqueezeBertForSequenceClassification extends SqueezeBertPreTrainedModel { @@ -1279,8 +1279,8 @@ export class SqueezeBertForSequenceClassification extends SqueezeBertPreTrainedM * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } export class SqueezeBertForQuestionAnswering extends SqueezeBertPreTrainedModel { @@ -1291,8 +1291,8 @@ export class SqueezeBertForQuestionAnswering extends SqueezeBertPreTrainedModel * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + let { start_logits, end_logits } = await super._call(model_inputs); + return new QuestionAnsweringModelOutput(start_logits, end_logits); } } ////////////////////////////////////////////////// @@ -1310,8 +1310,8 @@ export class AlbertForSequenceClassification extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } export class AlbertForQuestionAnswering extends AlbertPreTrainedModel { @@ -1322,8 +1322,8 @@ export class AlbertForQuestionAnswering extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + let { start_logits, end_logits } = await super._call(model_inputs); + return new QuestionAnsweringModelOutput(start_logits, end_logits); } } export class AlbertForMaskedLM extends AlbertPreTrainedModel { @@ -1334,8 +1334,8 @@ export class AlbertForMaskedLM extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + let { logits } = await super._call(model_inputs); + return new MaskedLMOutput(logits); } } ////////////////////////////////////////////////// @@ -1690,8 +1690,8 @@ export class BartForSequenceClassification extends BartPretrainedModel { * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } @@ -1714,8 +1714,8 @@ export class RobertaForMaskedLM extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + let { logits } = await super._call(model_inputs); + return new MaskedLMOutput(logits); } } @@ -1731,8 +1731,8 @@ export class RobertaForSequenceClassification extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } @@ -1748,8 +1748,8 @@ export class RobertaForQuestionAnswering extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + let { start_logits, end_logits } = await super._call(model_inputs); + return new QuestionAnsweringModelOutput(start_logits, end_logits); } } ////////////////////////////////////////////////// @@ -2266,8 +2266,8 @@ export class ViTForImageClassification extends ViTPreTrainedModel { * @param {any} model_inputs */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + let { logits } = await super._call(model_inputs); + return new SequenceClassifierOutput(logits); } } ////////////////////////////////////////////////// @@ -2279,8 +2279,8 @@ export class DetrForObjectDetection extends DetrPreTrainedModel { * @param {any} model_inputs */ async _call(model_inputs) { - let output = (await super._call(model_inputs)); - return new DetrObjectDetectionOutput(output.logits, output.pred_boxes) + let { logits, pred_boxes } = await super._call(model_inputs); + return new DetrObjectDetectionOutput(logits, pred_boxes); } } @@ -2291,8 +2291,8 @@ export class DetrForSegmentation extends DetrPreTrainedModel { * @returns {Promise} Object containing segmentation outputs */ async _call(model_inputs) { - let output = (await super._call(model_inputs)); - return new DetrSegmentationOutput(output.logits, output.pred_boxes, output.pred_masks); + let { logits, pred_boxes, pred_masks } = await super._call(model_inputs); + return new DetrSegmentationOutput(logits, pred_boxes, pred_masks); } } @@ -2336,8 +2336,8 @@ export class SamModel extends SamPreTrainedModel { */ async _call(model_inputs) { // TODO split into encoder and decoder - let output = (await super._call(model_inputs)); - return new SamImageSegmentationOutput(output.iou_scores, output.pred_masks); + let { iou_scores, pred_masks } = await super._call(model_inputs); + return new SamImageSegmentationOutput(iou_scores, pred_masks); } } From f520c3751d48f79fe7845d49210a2bac9bc5a4ed Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 14 Jun 2023 20:12:17 +0200 Subject: [PATCH 03/22] Allow decoder-only models to be called --- src/models.js | 50 +++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/src/models.js b/src/models.js index 48d9254af..5b0fc9b1d 100644 --- a/src/models.js +++ b/src/models.js @@ -462,7 +462,7 @@ async function seq2seqRunBeam(self, beam, { * Forward pass of the text generation model. * @param {Object} self The text generation model object. * @param {Object} model_inputs The input data to be used for the forward pass. - * @returns {Promise} Promise that resolves with an object containing the logits and past key values. + * @returns {Promise} Promise that resolves with an object containing the logits and past key values. */ async function textgen_forward(self, model_inputs) { let past_key_values = model_inputs.past_key_values; @@ -477,7 +477,7 @@ async function textgen_forward(self, model_inputs) { let logits = decoderResults.logits; past_key_values = self.getPastKeyValues(decoderResults, past_key_values); - return { logits, past_key_values }; + return new CausalLMOutputWithPast(logits, past_key_values); } /** @@ -639,13 +639,33 @@ export class PreTrainedModel extends Callable { return new this(...info); } + /** + * Whether this model can generate sequences with `.generate()`. + * @returns {boolean} `true` if this model can generate sequences, `false` otherwise. + */ + can_generate() { + if (!('generate' in this)) return false; + + if (this.config.is_encoder_decoder) { + return 'decoder_merged_session' in this; + } else { + return 'session' in this; + } + } /** * Runs the model with the provided inputs * @param {Object} model_inputs Object containing input tensors * @returns {Promise} Object containing output tensors */ async _call(model_inputs) { - return await sessionRun(this.session, model_inputs); + // TODO: prepare inputs + if (this.config.is_encoder_decoder || !('generate' in this)) { + // Either encoder-decoder or encoder-only model. + return await sessionRun(this.session, model_inputs); + } else { + // Decoder-only model. + return await textgen_forward(this, model_inputs); + } } /** @@ -1010,7 +1030,9 @@ export class PreTrainedModel extends Callable { * @param {boolean} [hasDecoder=false] Whether the model has a decoder. */ addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) { - if (pastKeyValues === null) { + if (pastKeyValues) { + Object.assign(decoderFeeds, pastKeyValues) + } else { // TODO support batches (i.e., batch_size > 1) if (hasDecoder) { // @ts-ignore @@ -1038,9 +1060,6 @@ export class PreTrainedModel extends Callable { decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) } } - - } else { - Object.assign(decoderFeeds, pastKeyValues) } } } @@ -2916,3 +2935,20 @@ export class QuestionAnsweringModelOutput extends ModelOutput { this.end_logits = end_logits; } } + + +/** + * Base class for causal language model (or autoregressive) outputs. + */ +export class CausalLMOutputWithPast extends ModelOutput { + /** + * @param {Tensor} logits Prediction scores of the language modeling head (scores for each vocabulary token before softmax). + * @param {Tensor} past_key_values Contains pre-computed hidden-states (key and values in the self-attention blocks) + * that can be used (see `past_key_values` input) to speed up sequential decoding. + */ + constructor(logits, past_key_values) { + super(); + this.logits = logits; + this.past_key_values = past_key_values; + } +} \ No newline at end of file From d7a8ec30987acd99685586d6e885f10a9508b8f1 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Wed, 14 Jun 2023 23:54:39 +0200 Subject: [PATCH 04/22] Fix detection of default call function --- src/models.js | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/src/models.js b/src/models.js index 5b0fc9b1d..f8e082287 100644 --- a/src/models.js +++ b/src/models.js @@ -639,19 +639,6 @@ export class PreTrainedModel extends Callable { return new this(...info); } - /** - * Whether this model can generate sequences with `.generate()`. - * @returns {boolean} `true` if this model can generate sequences, `false` otherwise. - */ - can_generate() { - if (!('generate' in this)) return false; - - if (this.config.is_encoder_decoder) { - return 'decoder_merged_session' in this; - } else { - return 'session' in this; - } - } /** * Runs the model with the provided inputs * @param {Object} model_inputs Object containing input tensors @@ -659,8 +646,8 @@ export class PreTrainedModel extends Callable { */ async _call(model_inputs) { // TODO: prepare inputs - if (this.config.is_encoder_decoder || !('generate' in this)) { - // Either encoder-decoder or encoder-only model. + if (this.config.is_encoder_decoder || !('decoder_merged_session' in this)) { + // Either encoder-decoder or encoder-only model. In both cases, call the encoder. return await sessionRun(this.session, model_inputs); } else { // Decoder-only model. From f4759c855b70e8677be7904bd761de00c846f5a1 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 00:00:11 +0200 Subject: [PATCH 05/22] Update default `_call` JSDoc --- src/models.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models.js b/src/models.js index f8e082287..a1881f625 100644 --- a/src/models.js +++ b/src/models.js @@ -642,7 +642,7 @@ export class PreTrainedModel extends Callable { /** * Runs the model with the provided inputs * @param {Object} model_inputs Object containing input tensors - * @returns {Promise} Object containing output tensors + * @returns {Promise} Object containing output tensors */ async _call(model_inputs) { // TODO: prepare inputs From 4236b6ce8b02f701075d32d6a83f13352470037c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 00:02:58 +0200 Subject: [PATCH 06/22] Mark helper functions as private --- src/models.js | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/models.js b/src/models.js index a1881f625..7979d7ee3 100644 --- a/src/models.js +++ b/src/models.js @@ -86,6 +86,7 @@ const { InferenceSession, Tensor: ONNXTensor } = ONNX; * @param {string} fileName The name of the model file. * @param {PretrainedOptions} options Additional options for loading the model. * @returns {Promise} A Promise that resolves to an InferenceSession object. + * @private */ async function constructSession(pretrained_model_name_or_path, fileName, options) { // TODO add option for user to force specify their desired execution provider @@ -122,6 +123,7 @@ async function constructSession(pretrained_model_name_or_path, fileName, options * @param {InferenceSession} session The InferenceSession object to run. * @param {Object} inputs An object that maps input names to input tensors. * @returns {Promise} A Promise that resolves to an object that maps output names to output tensors. + * @private */ async function sessionRun(session, inputs) { @@ -166,6 +168,7 @@ async function sessionRun(session, inputs) { * Replaces ONNX Tensor objects with custom Tensor objects to support additional functions. * @param {Object} obj The object to replace tensor objects in. * @returns {Object} The object with tensor objects replaced by custom Tensor objects. + * @private */ function replaceTensors(obj) { // Convert ONNX Tensors with our custom Tensor class @@ -219,6 +222,7 @@ function toI64Tensor(items) { * @param {Object} self The calling object instance. * @param {Tensor} tokens The input tokens. * @returns {Tensor} The attention mask tensor. + * @private */ function _prepare_attention_mask(self, tokens) { @@ -251,6 +255,7 @@ function _prepare_attention_mask(self, tokens) { * Creates a boolean tensor with a single value. * @param {boolean} value The value of the tensor. * @returns {Tensor} The boolean tensor. + * @private */ function boolTensor(value) { // Create boolean tensor @@ -263,6 +268,7 @@ function boolTensor(value) { * @param {string} pretrained_model_name_or_path The path to the model directory. * @param {PretrainedOptions} options Additional options for loading the model. * @returns {Promise} A promise that resolves with information about the loaded model. + * @private */ async function loadAutoModel(pretrained_model_name_or_path, options) { let config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options) @@ -279,6 +285,7 @@ async function loadAutoModel(pretrained_model_name_or_path, options) { * @param {string} pretrained_model_name_or_path The path to the model directory. * @param {PretrainedOptions} options Additional options for loading the model. * @returns {Promise} A promise that resolves with information about the loaded model. + * @private */ async function loadModel(pretrained_model_name_or_path, options) { let info = await Promise.all([ @@ -293,6 +300,7 @@ async function loadModel(pretrained_model_name_or_path, options) { * @param {string} pretrained_model_name_or_path The path to the model directory. * @param {PretrainedOptions} options Additional options for loading the model. * @returns {Promise} A promise that resolves with information about the loaded model. + * @private */ async function seq2seqLoadModel(pretrained_model_name_or_path, options) { let info = await Promise.all([ @@ -309,6 +317,7 @@ async function seq2seqLoadModel(pretrained_model_name_or_path, options) { * @param {string} pretrained_model_name_or_path The path to the model directory. * @param {PretrainedOptions} options Additional options for loading the model. * @returns {Promise} A promise that resolves with information about the loaded model. + * @private */ async function encoderDecoderLoadModel(pretrained_model_name_or_path, options) { let info = await Promise.all([ @@ -325,6 +334,7 @@ async function encoderDecoderLoadModel(pretrained_model_name_or_path, options) { * @param {string} pretrained_model_name_or_path The path to the model directory. * @param {PretrainedOptions} options Additional options for loading the model. * @returns {Promise} A promise that resolves with information about the loaded model. + * @private */ async function decoderLoadModel(pretrained_model_name_or_path, options) { let info = await Promise.all([ @@ -343,6 +353,7 @@ async function decoderLoadModel(pretrained_model_name_or_path, options) { * @param {string} [options.encoder_input_name='input_ids'] The name of the input tensor for the encoder. * @param {boolean} [options.add_decoder_pkv=true] Flag to add the decoder past key values. * @returns {Promise} Promise that resolves with the output of the seq2seq model. + * @private */ async function seq2seq_forward(self, model_inputs, { encoder_input_name = 'input_ids', @@ -386,6 +397,7 @@ async function seq2seq_forward(self, model_inputs, { * @param {number} numOutputTokens The maximum number of output tokens for the model. * @param {boolean} [requires_attention_mask=true] Flag to indicate if the model requires an attention mask. * @returns {Object[]} Array of beam search objects. + * @private */ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attention_mask = true) { let beams = []; @@ -432,6 +444,7 @@ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attent * @param {Object} options options * @param {string} [options.input_name='input_ids'] The name of the input tensor for the encoder. * @returns {Promise} Promise that resolves with the output of the seq2seq model for the given beam. + * @private */ async function seq2seqRunBeam(self, beam, { input_name = 'input_ids', @@ -463,6 +476,7 @@ async function seq2seqRunBeam(self, beam, { * @param {Object} self The text generation model object. * @param {Object} model_inputs The input data to be used for the forward pass. * @returns {Promise} Promise that resolves with an object containing the logits and past key values. + * @private */ async function textgen_forward(self, model_inputs) { let past_key_values = model_inputs.past_key_values; @@ -487,6 +501,7 @@ async function textgen_forward(self, model_inputs) { * @param {number} numOutputTokens The maximum number of tokens to generate for each beam. * @param {Tensor} [inputs_attention_mask] The attention mask tensor for the input token IDs. * @returns {Object[]} An array of beams initialized with the given inputs and parameters. + * @private */ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attention_mask) { let beams = []; @@ -537,6 +552,7 @@ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio * @param {Object} beam.past_key_values The past key values. * @param {number[]} beam.output_token_ids The output token ids. * @returns {Promise} The output of the generation step. + * @private */ async function textgenRunBeam(self, beam) { let attnMaskData = new BigInt64Array(beam.input.data.length + beam.output_token_ids.length).fill(1n) @@ -565,6 +581,7 @@ async function textgenRunBeam(self, beam) { * Update a beam with a new token ID. * @param {Object} beam The beam to update. * @param {number} newTokenId The new token ID to add to the beam's output. + * @private */ function textgenUpdatebeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; From 0aa9b01c60ee177a86514dbf68372b2d9b3b8e63 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 00:11:49 +0200 Subject: [PATCH 07/22] Remove outdated comments --- src/models.js | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/models.js b/src/models.js index 7979d7ee3..45ebaf013 100644 --- a/src/models.js +++ b/src/models.js @@ -171,8 +171,6 @@ async function sessionRun(session, inputs) { * @private */ function replaceTensors(obj) { - // Convert ONNX Tensors with our custom Tensor class - // to support additional functions for (let prop in obj) { if (obj[prop] instanceof ONNXTensor) { obj[prop] = new Tensor(obj[prop]); @@ -258,7 +256,6 @@ function _prepare_attention_mask(self, tokens) { * @private */ function boolTensor(value) { - // Create boolean tensor return new Tensor('bool', [value], [1]); } From a41126eb8b6620fde9dc2259648977b7b324b55f Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 00:13:51 +0200 Subject: [PATCH 08/22] Fix JSDoc --- src/models.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models.js b/src/models.js index 45ebaf013..fc59e7773 100644 --- a/src/models.js +++ b/src/models.js @@ -656,7 +656,7 @@ export class PreTrainedModel extends Callable { /** * Runs the model with the provided inputs * @param {Object} model_inputs Object containing input tensors - * @returns {Promise} Object containing output tensors + * @returns {Promise} Object containing output tensors */ async _call(model_inputs) { // TODO: prepare inputs From 6df5b04624c9c6500949c6f28c74b4399ead3615 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 00:21:35 +0200 Subject: [PATCH 09/22] Rename functions --- src/models.js | 58 ++++++++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/src/models.js b/src/models.js index fc59e7773..e27c974f7 100644 --- a/src/models.js +++ b/src/models.js @@ -352,7 +352,7 @@ async function decoderLoadModel(pretrained_model_name_or_path, options) { * @returns {Promise} Promise that resolves with the output of the seq2seq model. * @private */ -async function seq2seq_forward(self, model_inputs, { +async function seq2seqForward(self, model_inputs, { encoder_input_name = 'input_ids', add_decoder_pkv = true } = {}) { @@ -469,13 +469,13 @@ async function seq2seqRunBeam(self, beam, { } /** - * Forward pass of the text generation model. - * @param {Object} self The text generation model object. + * Forward pass of a decoder model. + * @param {Object} self The decoder model. * @param {Object} model_inputs The input data to be used for the forward pass. * @returns {Promise} Promise that resolves with an object containing the logits and past key values. * @private */ -async function textgen_forward(self, model_inputs) { +async function decoderForward(self, model_inputs) { let past_key_values = model_inputs.past_key_values; let decoderFeeds = { input_ids: model_inputs.input_ids, @@ -500,7 +500,7 @@ async function textgen_forward(self, model_inputs) { * @returns {Object[]} An array of beams initialized with the given inputs and parameters. * @private */ -function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attention_mask) { +function decoderStartBeams(self, inputTokenIds, numOutputTokens, inputs_attention_mask) { let beams = []; let beamId = 0; @@ -541,7 +541,7 @@ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio /** * Runs a single step of the text generation process for a given beam. * - * @param {Object} self The textgen object. + * @param {Object} self The decoder object. * @param {Object} beam The beam to run. * @param {Tensor} beam.input The input tensor. * @param {Tensor} beam.model_input_ids The input ids to the model. @@ -551,7 +551,7 @@ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio * @returns {Promise} The output of the generation step. * @private */ -async function textgenRunBeam(self, beam) { +async function decoderRunBeam(self, beam) { let attnMaskData = new BigInt64Array(beam.input.data.length + beam.output_token_ids.length).fill(1n) // 1. Prepare @@ -580,7 +580,7 @@ async function textgenRunBeam(self, beam) { * @param {number} newTokenId The new token ID to add to the beam's output. * @private */ -function textgenUpdatebeam(beam, newTokenId) { +function decoderUpdatebeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; beam.model_input_ids = new Tensor('int64', [BigInt(newTokenId)], [1, 1]); } @@ -665,7 +665,7 @@ export class PreTrainedModel extends Callable { return await sessionRun(this.session, model_inputs); } else { // Decoder-only model. - return await textgen_forward(this, model_inputs); + return await decoderForward(this, model_inputs); } } @@ -677,6 +677,8 @@ export class PreTrainedModel extends Callable { * @throws {Error} This method must be implemented in subclasses. */ async forward(model_inputs) { + sessionRun(); + throw Error("forward should be implemented in subclasses.") } @@ -1467,7 +1469,7 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { * @returns {Promise} The model output. */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -1581,7 +1583,7 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { * @returns {Promise} A Promise that resolves to the model outputs. */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -1698,7 +1700,7 @@ export class BartForConditionalGeneration extends BartPretrainedModel { * @returns {Promise} The model output. */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } @@ -1917,7 +1919,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { * @returns {Promise} The model output. */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs, { + return await seq2seqForward(this, model_inputs, { encoder_input_name: 'input_features', }); } @@ -2013,7 +2015,7 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs, { + return await seq2seqForward(this, model_inputs, { encoder_input_name: 'pixel_values', add_decoder_pkv: false }) @@ -2082,7 +2084,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @returns {any} A Beam object representing the initialized beam. */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { - return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) + return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } /** @@ -2091,7 +2093,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @returns {Promise} The updated beam after a single generation step. */ async runBeam(beam) { - return await textgenRunBeam(this, beam); + return await decoderRunBeam(this, beam); } /** @@ -2100,7 +2102,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @param {number} newTokenId The new generated token id to be added to the beam. */ updateBeam(beam, newTokenId) { - return textgenUpdatebeam(beam, newTokenId); + return decoderUpdatebeam(beam, newTokenId); } /** @@ -2109,7 +2111,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await textgen_forward(this, model_inputs) + return await decoderForward(this, model_inputs) } } @@ -2159,7 +2161,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @returns {any} A Beam object representing the initialized beam. */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { - return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) + return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } /** @@ -2168,7 +2170,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @returns {Promise} The updated beam after a single generation step. */ async runBeam(beam) { - return await textgenRunBeam(this, beam); + return await decoderRunBeam(this, beam); } /** @@ -2177,7 +2179,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @param {number} newTokenId The new generated token id to be added to the beam. */ updateBeam(beam, newTokenId) { - return textgenUpdatebeam(beam, newTokenId); + return decoderUpdatebeam(beam, newTokenId); } /** @@ -2186,7 +2188,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await textgen_forward(this, model_inputs) + return await decoderForward(this, model_inputs) } } @@ -2246,7 +2248,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @returns {any} A Beam object representing the initialized beam. */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { - return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) + return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } /** @@ -2255,7 +2257,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @returns {Promise} The updated beam after a single generation step. */ async runBeam(beam) { - return await textgenRunBeam(this, beam); + return await decoderRunBeam(this, beam); } /** @@ -2264,7 +2266,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @param {number} newTokenId The new generated token id to be added to the beam. */ updateBeam(beam, newTokenId) { - return textgenUpdatebeam(beam, newTokenId); + return decoderUpdatebeam(beam, newTokenId); } /** @@ -2273,7 +2275,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await textgen_forward(this, model_inputs) + return await decoderForward(this, model_inputs) } } @@ -2482,7 +2484,7 @@ export class MarianMTModel extends MarianPreTrainedModel { * @returns {Promise} */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -2588,7 +2590,7 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { * @returns {Promise} */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// From 768a842a37baafebf32aca06a7bc87941d337acf Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 02:40:11 +0200 Subject: [PATCH 10/22] Specify model types Reduces major code duplication --- src/models.js | 434 ++++++++++++++------------------------------------ 1 file changed, 119 insertions(+), 315 deletions(-) diff --git a/src/models.js b/src/models.js index e27c974f7..2161ee177 100644 --- a/src/models.js +++ b/src/models.js @@ -78,8 +78,89 @@ const { InferenceSession, Tensor: ONNXTensor } = ONNX; * @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions */ + +////////////////////////////////////////////////// +// Model types: used internally +class ModelType { }; + +// Either encoder-only or encoder-decoder (and will be decided by `model.config.is_encoder_decoder`) +class EncoderModelType extends ModelType { }; +class EncoderOnlyModelType extends EncoderModelType { }; +class EncoderDecoderModelType extends EncoderModelType { }; +class Seq2SeqModelType extends EncoderDecoderModelType { }; +class DecoderOnlyModelType extends ModelType { }; +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // Helper functions +function issubclass(a, b) { + return a === b || a.prototype instanceof b; +} +/** + * Loads an model from the specified path. + * @param {Object} cls The class of the model. + * @param {string} pretrained_model_name_or_path The path to the model directory. + * @param {PretrainedOptions} options Additional options for loading the model. + * @returns {Promise} A promise that resolves with information about the loaded model. + * @private + */ +async function load(cls, pretrained_model_name_or_path, options) { + const cModelType = cls.MODEL_TYPE; + if (cls.MODEL_TYPE === null) { + throw new Error("`MODEL_TYPE` not implemented for this model."); + } + if (issubclass(cModelType, DecoderOnlyModelType)) { + return await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + ]); + + } else if (issubclass(cModelType, Seq2SeqModelType)) { + return await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'encoder_model', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), + ]); + + } else if (issubclass(cModelType, EncoderDecoderModelType)) { + return await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'encoder_model', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + ]); + + } else if (issubclass(cModelType, EncoderOnlyModelType)) { + return await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'model', options) + ]); + + } else if (issubclass(cModelType, EncoderModelType)) { + let config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options) + let modelName = config.is_encoder_decoder ? 'encoder_model' : 'model'; + let session = await constructSession(pretrained_model_name_or_path, modelName, options); + return [config, session]; + } else { + throw Error(`Unable to determine model type: ${cModelType?.constructor?.name}`); + } +} + +/** + * Helper function to determine which `call` method to run for a specific model. + * @param {Object} self The calling object + * @param {Object} model_inputs The inputs to be sent to the model + * @returns {Promise} The model output + */ +async function call(self, model_inputs) { + if (issubclass(self.constructor.MODEL_TYPE, DecoderOnlyModelType)) { + return await decoderForward(self, model_inputs); + } else { + return await sessionRun(self.session, model_inputs); + } +} + /** * Constructs an InferenceSession using a model file located at the specified path. * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. @@ -222,7 +303,7 @@ function toI64Tensor(items) { * @returns {Tensor} The attention mask tensor. * @private */ -function _prepare_attention_mask(self, tokens) { +function prepareAttentionMask(self, tokens) { // Prepare attention mask let pad_token_id = self.config.pad_token_id ?? null; @@ -259,88 +340,6 @@ function boolTensor(value) { return new Tensor('bool', [value], [1]); } - -/** - * Loads a model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - * @private - */ -async function loadAutoModel(pretrained_model_name_or_path, options) { - let config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options) - - let modelName = config.is_encoder_decoder ? 'encoder_model' : 'model'; - - let session = await constructSession(pretrained_model_name_or_path, modelName, options); - - return [config, session]; -} - -/** - * Loads a model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - * @private - */ -async function loadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'model', options) - ]); - return info; -} - -/** - * Loads a sequence-to-sequence model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - * @private - */ -async function seq2seqLoadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), - ]); - return info; -} - -/** - * Loads an encoder-decoder model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - * @private - */ -async function encoderDecoderLoadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - ]) - return info; -} - - -/** - * Loads a decoder model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - * @private - */ -async function decoderLoadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - ]) - return info; -} - // JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in /** * Perform forward pass on the seq2seq model. @@ -359,7 +358,7 @@ async function seq2seqForward(self, model_inputs, { let encoderOutputs = model_inputs.encoder_outputs; let pastKeyValues = model_inputs.past_key_values; - if (encoderOutputs === null) { + if (!encoderOutputs) { const encoderFeeds = { [encoder_input_name]: model_inputs[encoder_input_name], } @@ -425,7 +424,7 @@ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attent } if (requires_attention_mask) { - start.attention_mask = _prepare_attention_mask(self, tokens); + start.attention_mask = prepareAttentionMask(self, tokens); } beams.push(start); @@ -516,7 +515,7 @@ function decoderStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio attn_mask.dims = [1, ...attn_mask.dims] } else { - attn_mask = _prepare_attention_mask(self, tokens) + attn_mask = prepareAttentionMask(self, tokens) } let start = { @@ -592,7 +591,7 @@ function decoderUpdatebeam(beam, newTokenId) { * @extends Callable */ export class PreTrainedModel extends Callable { - static LOAD_FUNCTION = loadAutoModel; + static MODEL_TYPE = EncoderModelType; /** * Creates a new instance of the `PreTrainedModel` class. @@ -640,7 +639,7 @@ export class PreTrainedModel extends Callable { local_files_only = false, revision = 'main', } = {}) { - let info = await this.LOAD_FUNCTION(pretrained_model_name_or_path, { + let info = await load(this, pretrained_model_name_or_path, { quantized, progress_callback, config, @@ -659,14 +658,7 @@ export class PreTrainedModel extends Callable { * @returns {Promise} Object containing output tensors */ async _call(model_inputs) { - // TODO: prepare inputs - if (this.config.is_encoder_decoder || !('decoder_merged_session' in this)) { - // Either encoder-decoder or encoder-only model. In both cases, call the encoder. - return await sessionRun(this.session, model_inputs); - } else { - // Decoder-only model. - return await decoderForward(this, model_inputs); - } + return await call(this, model_inputs); } /** @@ -677,8 +669,6 @@ export class PreTrainedModel extends Callable { * @throws {Error} This method must be implemented in subclasses. */ async forward(model_inputs) { - sessionRun(); - throw Error("forward should be implemented in subclasses.") } @@ -1386,6 +1376,8 @@ export class T5Model extends T5PreTrainedModel { * @extends T5PreTrainedModel */ export class T5ForConditionalGeneration extends T5PreTrainedModel { + static MODEL_TYPE = Seq2SeqModelType; + /** * Creates a new instance of the `T5ForConditionalGeneration` class. * @param {Object} config The model configuration. @@ -1407,34 +1399,6 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { this.encoder_dim_kv = this.config.d_kv; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `T5ForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Generates the start beams for a given set of inputs and output length. * @param {number[][]} inputs The input token IDs. @@ -1498,6 +1462,8 @@ export class MT5Model extends MT5PreTrainedModel { * @extends MT5PreTrainedModel */ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { + static MODEL_TYPE = Seq2SeqModelType; + /** * Creates a new instance of the `MT5ForConditionalGeneration` class. * @param {any} config The model configuration. @@ -1519,34 +1485,6 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { this.encoder_dim_kv = this.config.d_kv; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `MT5ForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Generates the start beams for the given input tokens and output sequence length. * @@ -1617,6 +1555,8 @@ export class BartModel extends BartPretrainedModel { * @extends BartPretrainedModel */ export class BartForConditionalGeneration extends BartPretrainedModel { + static MODEL_TYPE = Seq2SeqModelType; + /** * Creates a new instance of the `BartForConditionalGeneration` class. * @param {Object} config The configuration object for the Bart model. @@ -1637,33 +1577,6 @@ export class BartForConditionalGeneration extends BartPretrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `BartForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } /** * Returns the initial beam for generating output text. @@ -1803,6 +1716,8 @@ export class WhisperModel extends WhisperPreTrainedModel { * @extends WhisperPreTrainedModel */ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { + static MODEL_TYPE = Seq2SeqModelType; + /** * Creates a new instance of the `WhisperForConditionalGeneration` class. * @param {Object} config Configuration object for the model. @@ -1854,34 +1769,6 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { return super.generate(inputs, generation_config, logits_processor) } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `WhisperForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Gets the start beams for generating outputs. * @param {Array} inputTokenIds Array of input token IDs. @@ -1947,34 +1834,6 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { this.dim_kv = this.config.decoder.n_embd / this.num_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `VisionEncoderDecoderModel` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await encoderDecoderLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Generate beam search outputs for the given input pixels and number of output tokens. * @@ -2058,7 +1917,7 @@ export class GPT2Model extends GPT2PreTrainedModel { * @extends GPT2PreTrainedModel */ export class GPT2LMHeadModel extends GPT2PreTrainedModel { - static LOAD_FUNCTION = decoderLoadModel; + static MODEL_TYPE = DecoderOnlyModelType; /** * Creates a new instance of the `GPT2LMHeadModel` class. @@ -2135,7 +1994,7 @@ export class GPTNeoModel extends GPTNeoPreTrainedModel { } export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { - static LOAD_FUNCTION = decoderLoadModel; + static MODEL_TYPE = DecoderOnlyModelType; /** * Creates a new instance of the `GPTNeoForCausalLM` class. @@ -2222,7 +2081,7 @@ export class CodeGenModel extends CodeGenPreTrainedModel { * @extends CodeGenPreTrainedModel */ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { - static LOAD_FUNCTION = decoderLoadModel; + static MODEL_TYPE = DecoderOnlyModelType; /** * Creates a new instance of the `CodeGenForCausalLM` class. @@ -2402,6 +2261,8 @@ export class MarianModel extends MarianPreTrainedModel { } export class MarianMTModel extends MarianPreTrainedModel { + static MODEL_TYPE = Seq2SeqModelType; + /** * Creates a new instance of the `MarianMTModel` class. * @param {Object} config The model configuration object. @@ -2423,34 +2284,6 @@ export class MarianMTModel extends MarianPreTrainedModel { this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `MarianMTModel` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Initializes and returns the beam for text generation task * @param {any[]} inputs The input token ids. @@ -2508,6 +2341,8 @@ export class M2M100Model extends M2M100PreTrainedModel { } export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { + static MODEL_TYPE = Seq2SeqModelType; + /** * Creates a new instance of the `M2M100ForConditionalGeneration` class. * @param {Object} config The model configuration object. @@ -2529,33 +2364,6 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `M2M100ForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } /** * Initializes and returns the beam for text generation task @@ -2617,9 +2425,9 @@ export class PretrainedMixin { static BASE_IF_FAIL = false; /** - * The function to use to load the pretrained model. + * The type of model. We set this to null since we require each `AutoModel` to override this. */ - static LOAD_FUNCTION = null; + static MODEL_TYPE = null; /** * Instantiate one of the model classes of the library from a pretrained model. @@ -2634,7 +2442,7 @@ export class PretrainedMixin { * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. * @param {PretrainedOptions} options Additional options for loading the model. * - * @returns {Promise} A new instance of the PreTrainedModel class. + * @returns {Promise} A new instance of the `PreTrainedModel` class. */ static async from_pretrained(pretrained_model_name_or_path, { quantized = true, @@ -2644,11 +2452,7 @@ export class PretrainedMixin { local_files_only = false, revision = 'main', } = {}) { - if (this.LOAD_FUNCTION === null) { - throw new Error("`LOAD_FUNCTION` not implemented for this model"); - } - - let info = await this.LOAD_FUNCTION(pretrained_model_name_or_path, { + let info = await load(this, pretrained_model_name_or_path, { quantized, progress_callback, config, @@ -2678,7 +2482,7 @@ export class PretrainedMixin { * let model = await AutoModel.from_pretrained('bert-base-uncased'); */ export class AutoModel extends PretrainedMixin { - static LOAD_FUNCTION = loadAutoModel; + static MODEL_TYPE = EncoderModelType; // Assume to be an encoder-only or encoder-decoder model. static BASE_IF_FAIL = true; static MODEL_CLASS_MAPPING = { 'bert': BertModel, @@ -2709,7 +2513,7 @@ export class AutoModel extends PretrainedMixin { * let model = await AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english'); */ export class AutoModelForSequenceClassification extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'bert': BertForSequenceClassification, 'albert': AlbertForSequenceClassification, @@ -2729,7 +2533,7 @@ export class AutoModelForSequenceClassification extends PretrainedMixin { * let model = await AutoModelForTokenClassification.from_pretrained('Davlan/distilbert-base-multilingual-cased-ner-hrl'); */ export class AutoModelForTokenClassification extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'bert': BertForTokenClassification, 'distilbert': DistilBertForTokenClassification, @@ -2745,7 +2549,7 @@ export class AutoModelForTokenClassification extends PretrainedMixin { * let model = await AutoModelForSeq2SeqLM.from_pretrained('t5-small'); */ export class AutoModelForSeq2SeqLM extends PretrainedMixin { - static LOAD_FUNCTION = seq2seqLoadModel; + static MODEL_TYPE = Seq2SeqModelType; static MODEL_CLASS_MAPPING = { 't5': T5ForConditionalGeneration, 'mt5': MT5ForConditionalGeneration, @@ -2764,7 +2568,7 @@ export class AutoModelForSeq2SeqLM extends PretrainedMixin { * let model = await AutoModelForCausalLM.from_pretrained('gpt2'); */ export class AutoModelForCausalLM extends PretrainedMixin { - static LOAD_FUNCTION = decoderLoadModel; + static MODEL_TYPE = DecoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'gpt2': GPT2LMHeadModel, 'gpt_neo': GPTNeoForCausalLM, @@ -2780,7 +2584,7 @@ export class AutoModelForCausalLM extends PretrainedMixin { * let model = await AutoModelForMaskedLM.from_pretrained('bert-base-uncased'); */ export class AutoModelForMaskedLM extends PretrainedMixin { - static LOAD_FUNCTION = loadAutoModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'bert': BertForMaskedLM, 'albert': AlbertForMaskedLM, @@ -2799,7 +2603,7 @@ export class AutoModelForMaskedLM extends PretrainedMixin { * let model = await AutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad'); */ export class AutoModelForQuestionAnswering extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'bert': BertForQuestionAnswering, 'albert': AlbertForQuestionAnswering, @@ -2818,7 +2622,7 @@ export class AutoModelForQuestionAnswering extends PretrainedMixin { * let model = await AutoModelForVision2Seq.from_pretrained('nlpconnect/vit-gpt2-image-captioning'); */ export class AutoModelForVision2Seq extends PretrainedMixin { - static LOAD_FUNCTION = encoderDecoderLoadModel; + static MODEL_TYPE = EncoderDecoderModelType; static MODEL_CLASS_MAPPING = { 'vision-encoder-decoder': VisionEncoderDecoderModel } @@ -2832,7 +2636,7 @@ export class AutoModelForVision2Seq extends PretrainedMixin { * let model = await AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224'); */ export class AutoModelForImageClassification extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'vit': ViTForImageClassification, } @@ -2846,7 +2650,7 @@ export class AutoModelForImageClassification extends PretrainedMixin { * let model = await AutoModelForImageSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic'); */ export class AutoModelForImageSegmentation extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'detr': DetrForSegmentation, } @@ -2860,7 +2664,7 @@ export class AutoModelForImageSegmentation extends PretrainedMixin { * let model = await AutoModelForObjectDetection.from_pretrained('facebook/detr-resnet-50'); */ export class AutoModelForObjectDetection extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'detr': DetrForObjectDetection, } @@ -2874,7 +2678,7 @@ export class AutoModelForObjectDetection extends PretrainedMixin { * let model = await AutoModelForMaskGeneration.from_pretrained('Xenova/sam-vit-base'); */ export class AutoModelForMaskGeneration extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; + static MODEL_TYPE = EncoderOnlyModelType; static MODEL_CLASS_MAPPING = { 'sam': SamModel, } From dd6e363c0772d00cf4a21dcffbd31fded5045db3 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 17:51:45 +0200 Subject: [PATCH 11/22] Improve model output classes --- src/models.js | 255 ++++++++++++++++++++++++++++---------------------- 1 file changed, 142 insertions(+), 113 deletions(-) diff --git a/src/models.js b/src/models.js index 2161ee177..b3c955235 100644 --- a/src/models.js +++ b/src/models.js @@ -94,9 +94,18 @@ class DecoderOnlyModelType extends ModelType { }; ////////////////////////////////////////////////// // Helper functions + +/** + * Determines whether `a` is a subclass of `b`. + * @param {any} a The first item + * @param {any} b The second item + * @returns Whether `a` is a subclass of `b`. + * @private + */ function issubclass(a, b) { return a === b || a.prototype instanceof b; } + /** * Loads an model from the specified path. * @param {Object} cls The class of the model. @@ -148,16 +157,16 @@ async function load(cls, pretrained_model_name_or_path, options) { } /** - * Helper function to determine which `call` method to run for a specific model. + * Helper function to determine which `forward` method to run for a specific model. * @param {Object} self The calling object * @param {Object} model_inputs The inputs to be sent to the model - * @returns {Promise} The model output + * @returns {Promise} The model output */ -async function call(self, model_inputs) { +async function forward(self, model_inputs) { if (issubclass(self.constructor.MODEL_TYPE, DecoderOnlyModelType)) { return await decoderForward(self, model_inputs); } else { - return await sessionRun(self.session, model_inputs); + return await encoderForward(self, model_inputs); } } @@ -342,7 +351,7 @@ function boolTensor(value) { // JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in /** - * Perform forward pass on the seq2seq model. + * Perform forward pass on the seq2seq model (both encoder and decoder). * @param {Object} self The seq2seq model object. * @param {Object} model_inputs The input object for the model containing encoder and decoder inputs. * @param {Object} options The options @@ -355,35 +364,27 @@ async function seq2seqForward(self, model_inputs, { encoder_input_name = 'input_ids', add_decoder_pkv = true } = {}) { - let encoderOutputs = model_inputs.encoder_outputs; - let pastKeyValues = model_inputs.past_key_values; + let { encoder_outputs, past_key_values } = model_inputs; - if (!encoderOutputs) { - const encoderFeeds = { - [encoder_input_name]: model_inputs[encoder_input_name], - } - - if (self.session.inputNames.includes('attention_mask')) { - encoderFeeds.attention_mask = model_inputs.attention_mask - } - const encoderResults = await sessionRun(self.session, encoderFeeds); - encoderOutputs = encoderResults.last_hidden_state; + if (!encoder_outputs) { + // Encoder outputs are not given, so we must compute them. + encoder_outputs = (await encoderForward(self, model_inputs, encoder_input_name)).last_hidden_state; } let decoderFeeds = { input_ids: model_inputs.decoder_input_ids, - encoder_hidden_states: encoderOutputs, - use_cache_branch: boolTensor(pastKeyValues !== null) + encoder_hidden_states: encoder_outputs, + use_cache_branch: boolTensor(past_key_values !== null) }; if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { decoderFeeds.encoder_attention_mask = model_inputs.attention_mask } - self.addPastKeyValues(decoderFeeds, pastKeyValues, add_decoder_pkv); + self.addPastKeyValues(decoderFeeds, past_key_values, add_decoder_pkv); const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds); let logits = decoderResults.logits; - pastKeyValues = self.getPastKeyValues(decoderResults, pastKeyValues); - return new Seq2SeqLMOutput(logits, pastKeyValues, encoderOutputs); + past_key_values = self.getPastKeyValues(decoderResults, past_key_values); + return new Seq2SeqLMOutput({ logits, past_key_values, encoder_outputs }); } /** @@ -467,11 +468,31 @@ async function seq2seqRunBeam(self, beam, { return output; } +/** + * Forward pass of an encoder model. + * @param {Object} self The encoder model. + * @param {Object} model_inputs The input data to be used for the forward pass. + * @returns {Promise} Promise that resolves with an object containing the model's outputs. + * @private + */ +async function encoderForward(self, model_inputs, encoder_input_name = 'input_ids') { + const encoderFeeds = { + [encoder_input_name]: model_inputs[encoder_input_name], + } + + if (self.session.inputNames.includes('attention_mask')) { + encoderFeeds.attention_mask = model_inputs.attention_mask; + } + const encoderResults = await sessionRun(self.session, encoderFeeds); + return new BaseModelOutput(encoderResults); +} + + /** * Forward pass of a decoder model. * @param {Object} self The decoder model. * @param {Object} model_inputs The input data to be used for the forward pass. - * @returns {Promise} Promise that resolves with an object containing the logits and past key values. + * @returns {Promise} Promise that resolves with an object containing the logits and past key values. * @private */ async function decoderForward(self, model_inputs) { @@ -487,7 +508,7 @@ async function decoderForward(self, model_inputs) { let logits = decoderResults.logits; past_key_values = self.getPastKeyValues(decoderResults, past_key_values); - return new CausalLMOutputWithPast(logits, past_key_values); + return { logits, past_key_values }; } /** @@ -658,18 +679,18 @@ export class PreTrainedModel extends Callable { * @returns {Promise} Object containing output tensors */ async _call(model_inputs) { - return await call(this, model_inputs); + return await this.forward(model_inputs); } /** - * Forward method should be implemented in subclasses. - * @abstract + * Forward method for a pretrained model. If not overridden by a subclass, the correct forward method + * will be chosen based on the model type. * @param {Object} model_inputs The input data to the model in the format specified in the ONNX model. * @returns {Promise} The output data from the model in the format specified in the ONNX model. * @throws {Error} This method must be implemented in subclasses. */ async forward(model_inputs) { - throw Error("forward should be implemented in subclasses.") + return await forward(this, model_inputs); } /** @@ -1060,7 +1081,23 @@ export class PreTrainedModel extends Callable { // Base model output class export class ModelOutput { } - +/** + * Base class for model's outputs, with potential hidden states and attentions. + */ +export class BaseModelOutput extends ModelOutput { + /** + * @param {Object} output The output of the model. + * @param {Tensor} output.last_hidden_state Sequence of hidden-states at the output of the last layer of the model. + * @param {Tensor} [output.hidden_states] Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + * @param {Tensor} [output.attentions] Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + */ + constructor({ last_hidden_state, hidden_states = null, attentions = null }) { + super(); + this.last_hidden_state = last_hidden_state; + this.hidden_states = hidden_states; + this.attentions = attentions; + } +} ////////////////////////////////////////////////// // Bert models export class BertPreTrainedModel extends PreTrainedModel { } @@ -1078,8 +1115,7 @@ export class BertForMaskedLM extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for masked language modeling. */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new MaskedLMOutput(logits); + return new MaskedLMOutput(await super._call(model_inputs)); } } @@ -1095,8 +1131,7 @@ export class BertForSequenceClassification extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1112,8 +1147,7 @@ export class BertForTokenClassification extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for token classification. */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new TokenClassifierOutput(logits); + return new TokenClassifierOutput(await super._call(model_inputs)); } } @@ -1129,8 +1163,7 @@ export class BertForQuestionAnswering extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for question answering. */ async _call(model_inputs) { - let { start_logits, end_logits } = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(start_logits, end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1152,8 +1185,7 @@ export class DistilBertForSequenceClassification extends DistilBertPreTrainedMod * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1169,8 +1201,7 @@ export class DistilBertForTokenClassification extends DistilBertPreTrainedModel * @returns {Promise} An object containing the model's output logits for token classification. */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new TokenClassifierOutput(logits); + return new TokenClassifierOutput(await super._call(model_inputs)); } } @@ -1187,8 +1218,7 @@ export class DistilBertForQuestionAnswering extends DistilBertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for question answering. */ async _call(model_inputs) { - let { start_logits, end_logits } = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(start_logits, end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } @@ -1204,8 +1234,7 @@ export class DistilBertForMaskedLM extends DistilBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new MaskedLMOutput(logits); + return new MaskedLMOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1228,8 +1257,7 @@ export class MobileBertForMaskedLM extends MobileBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new MaskedLMOutput(logits); + return new MaskedLMOutput(await super._call(model_inputs)); } } @@ -1244,8 +1272,7 @@ export class MobileBertForSequenceClassification extends MobileBertPreTrainedMod * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1260,8 +1287,7 @@ export class MobileBertForQuestionAnswering extends MobileBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { start_logits, end_logits } = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(start_logits, end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1279,8 +1305,7 @@ export class SqueezeBertForMaskedLM extends SqueezeBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new MaskedLMOutput(logits); + return new MaskedLMOutput(await super._call(model_inputs)); } } export class SqueezeBertForSequenceClassification extends SqueezeBertPreTrainedModel { @@ -1291,8 +1316,7 @@ export class SqueezeBertForSequenceClassification extends SqueezeBertPreTrainedM * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } export class SqueezeBertForQuestionAnswering extends SqueezeBertPreTrainedModel { @@ -1303,8 +1327,7 @@ export class SqueezeBertForQuestionAnswering extends SqueezeBertPreTrainedModel * @returns {Promise} returned object */ async _call(model_inputs) { - let { start_logits, end_logits } = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(start_logits, end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1322,8 +1345,7 @@ export class AlbertForSequenceClassification extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } export class AlbertForQuestionAnswering extends AlbertPreTrainedModel { @@ -1334,8 +1356,7 @@ export class AlbertForQuestionAnswering extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { start_logits, end_logits } = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(start_logits, end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } export class AlbertForMaskedLM extends AlbertPreTrainedModel { @@ -1346,8 +1367,7 @@ export class AlbertForMaskedLM extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new MaskedLMOutput(logits); + return new MaskedLMOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1625,8 +1645,7 @@ export class BartForSequenceClassification extends BartPretrainedModel { * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1649,8 +1668,7 @@ export class RobertaForMaskedLM extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new MaskedLMOutput(logits); + return new MaskedLMOutput(await super._call(model_inputs)); } } @@ -1666,8 +1684,7 @@ export class RobertaForSequenceClassification extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1683,8 +1700,7 @@ export class RobertaForQuestionAnswering extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let { start_logits, end_logits } = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(start_logits, end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1970,7 +1986,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await decoderForward(this, model_inputs) + return await decoderForward(this, model_inputs); } } @@ -2047,7 +2063,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await decoderForward(this, model_inputs) + return await decoderForward(this, model_inputs); } } @@ -2134,7 +2150,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await decoderForward(this, model_inputs) + return await decoderForward(this, model_inputs); } } @@ -2147,8 +2163,7 @@ export class ViTForImageClassification extends ViTPreTrainedModel { * @param {any} model_inputs */ async _call(model_inputs) { - let { logits } = await super._call(model_inputs); - return new SequenceClassifierOutput(logits); + return new SequenceClassifierOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -2160,8 +2175,7 @@ export class DetrForObjectDetection extends DetrPreTrainedModel { * @param {any} model_inputs */ async _call(model_inputs) { - let { logits, pred_boxes } = await super._call(model_inputs); - return new DetrObjectDetectionOutput(logits, pred_boxes); + return new DetrObjectDetectionOutput(await super._call(model_inputs)); } } @@ -2172,17 +2186,18 @@ export class DetrForSegmentation extends DetrPreTrainedModel { * @returns {Promise} Object containing segmentation outputs */ async _call(model_inputs) { - let { logits, pred_boxes, pred_masks } = await super._call(model_inputs); - return new DetrSegmentationOutput(logits, pred_boxes, pred_masks); + return new DetrSegmentationOutput(await super._call(model_inputs)); } } export class DetrObjectDetectionOutput extends ModelOutput { /** - * @param {Tensor} logits - * @param {Tensor} pred_boxes + * @param {Object} output The output of the model. + * @param {Tensor} output.logits Classification logits (including no-object) for all queries. + * @param {Tensor} output.pred_boxes Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). + * These values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding possible padding). */ - constructor(logits, pred_boxes) { + constructor({ logits, pred_boxes }) { super(); this.logits = logits; this.pred_boxes = pred_boxes; @@ -2190,13 +2205,13 @@ export class DetrObjectDetectionOutput extends ModelOutput { } export class DetrSegmentationOutput extends ModelOutput { - /** - * @param {Tensor} logits The output logits of the model. - * @param {Tensor} pred_boxes Predicted boxes. - * @param {Tensor} pred_masks Predicted masks. + * @param {Object} output The output of the model. + * @param {Tensor} output.logits The output logits of the model. + * @param {Tensor} output.pred_boxes Predicted boxes. + * @param {Tensor} output.pred_masks Predicted masks. */ - constructor(logits, pred_boxes, pred_masks) { + constructor({ logits, pred_boxes, pred_masks }) { super(); this.logits = logits; this.pred_boxes = pred_boxes; @@ -2216,24 +2231,21 @@ export class SamModel extends SamPreTrainedModel { * @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`. */ async _call(model_inputs) { - // TODO split into encoder and decoder - let { iou_scores, pred_masks } = await super._call(model_inputs); - return new SamImageSegmentationOutput(iou_scores, pred_masks); + return new SamImageSegmentationOutput(await super._call(model_inputs)); } } /** * Base class for Segment-Anything model's output. - * - * @extends ModelOutput */ export class SamImageSegmentationOutput extends ModelOutput { /** - * @param {Tensor} iou_scores The output logits of the model. - * @param {Tensor} pred_masks Predicted boxes. + * @param {Object} output The output of the model. + * @param {Tensor} output.iou_scores The output logits of the model. + * @param {Tensor} output.pred_masks Predicted boxes. */ - constructor(iou_scores, pred_masks) { + constructor({ iou_scores, pred_masks }) { super(); this.iou_scores = iou_scores; this.pred_masks = pred_masks; @@ -2688,11 +2700,12 @@ export class AutoModelForMaskGeneration extends PretrainedMixin { ////////////////////////////////////////////////// export class Seq2SeqLMOutput extends ModelOutput { /** - * @param {Tensor} logits The output logits of the model. - * @param {Tensor} past_key_values An tensor of key/value pairs that represent the previous state of the model. - * @param {Tensor} encoder_outputs The output of the encoder in a sequence-to-sequence model. + * @param {Object} output The output of the model. + * @param {Tensor} output.logits The output logits of the model. + * @param {Tensor} output.past_key_values An tensor of key/value pairs that represent the previous state of the model. + * @param {Tensor} output.encoder_outputs The output of the encoder in a sequence-to-sequence model. */ - constructor(logits, past_key_values, encoder_outputs) { + constructor({ logits, past_key_values, encoder_outputs }) { super(); this.logits = logits; this.past_key_values = past_key_values; @@ -2700,43 +2713,58 @@ export class Seq2SeqLMOutput extends ModelOutput { } } +/** + * Base class for outputs of sentence classification models. + */ export class SequenceClassifierOutput extends ModelOutput { /** - * @param {Tensor} logits + * @param {Object} output The output of the model. + * @param {Tensor} output.logits classification (or regression if config.num_labels==1) scores (before SoftMax). */ - constructor(logits) { + constructor({ logits }) { super(); this.logits = logits; } } +/** + * Base class for outputs of token classification models. + */ export class TokenClassifierOutput extends ModelOutput { /** - * @param {Tensor} logits + * @param {Object} output The output of the model. + * @param {Tensor} output.logits Classification scores (before SoftMax). */ - constructor(logits) { + constructor({ logits }) { super(); this.logits = logits; } } - +/** + * Base class for masked language models outputs. + */ export class MaskedLMOutput extends ModelOutput { /** - * @param {Tensor} logits + * @param {Object} output The output of the model. + * @param {Tensor} output.logits Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). */ - constructor(logits) { + constructor({ logits }) { super(); this.logits = logits; } } +/** + * Base class for outputs of question answering models. + */ export class QuestionAnsweringModelOutput extends ModelOutput { /** - * @param {Tensor} start_logits The logits for start positions of the answer. - * @param {Tensor} end_logits The logits for end positions of the answer. + * @param {Object} output The output of the model. + * @param {Tensor} output.start_logits Span-start scores (before SoftMax). + * @param {Tensor} output.end_logits Span-end scores (before SoftMax). */ - constructor(start_logits, end_logits) { + constructor({ start_logits, end_logits }) { super(); this.start_logits = start_logits; this.end_logits = end_logits; @@ -2749,11 +2777,12 @@ export class QuestionAnsweringModelOutput extends ModelOutput { */ export class CausalLMOutputWithPast extends ModelOutput { /** - * @param {Tensor} logits Prediction scores of the language modeling head (scores for each vocabulary token before softmax). - * @param {Tensor} past_key_values Contains pre-computed hidden-states (key and values in the self-attention blocks) + * @param {Object} output The output of the model. + * @param {Tensor} output.logits Prediction scores of the language modeling head (scores for each vocabulary token before softmax). + * @param {Tensor} output.past_key_values Contains pre-computed hidden-states (key and values in the self-attention blocks) * that can be used (see `past_key_values` input) to speed up sequential decoding. */ - constructor(logits, past_key_values) { + constructor({ logits, past_key_values }) { super(); this.logits = logits; this.past_key_values = past_key_values; From 99aae03a8be4a581d917923299beb989c4a343f9 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 18:22:00 +0200 Subject: [PATCH 12/22] Remove `encoder_input_name` from seq2seq forward method --- src/models.js | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/models.js b/src/models.js index b3c955235..c48b201b6 100644 --- a/src/models.js +++ b/src/models.js @@ -355,20 +355,18 @@ function boolTensor(value) { * @param {Object} self The seq2seq model object. * @param {Object} model_inputs The input object for the model containing encoder and decoder inputs. * @param {Object} options The options - * @param {string} [options.encoder_input_name='input_ids'] The name of the input tensor for the encoder. * @param {boolean} [options.add_decoder_pkv=true] Flag to add the decoder past key values. * @returns {Promise} Promise that resolves with the output of the seq2seq model. * @private */ async function seq2seqForward(self, model_inputs, { - encoder_input_name = 'input_ids', add_decoder_pkv = true } = {}) { let { encoder_outputs, past_key_values } = model_inputs; if (!encoder_outputs) { // Encoder outputs are not given, so we must compute them. - encoder_outputs = (await encoderForward(self, model_inputs, encoder_input_name)).last_hidden_state; + encoder_outputs = (await encoderForward(self, model_inputs)).last_hidden_state; } let decoderFeeds = { input_ids: model_inputs.decoder_input_ids, @@ -472,19 +470,15 @@ async function seq2seqRunBeam(self, beam, { * Forward pass of an encoder model. * @param {Object} self The encoder model. * @param {Object} model_inputs The input data to be used for the forward pass. - * @returns {Promise} Promise that resolves with an object containing the model's outputs. + * @returns {Promise} Promise that resolves with an object containing the model's outputs. * @private */ -async function encoderForward(self, model_inputs, encoder_input_name = 'input_ids') { - const encoderFeeds = { - [encoder_input_name]: model_inputs[encoder_input_name], - } - +async function encoderForward(self, model_inputs) { + const encoderFeeds = { ...model_inputs }; // Shallow copy if (self.session.inputNames.includes('attention_mask')) { encoderFeeds.attention_mask = model_inputs.attention_mask; } - const encoderResults = await sessionRun(self.session, encoderFeeds); - return new BaseModelOutput(encoderResults); + return await sessionRun(self.session, encoderFeeds);; } @@ -1822,9 +1816,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { * @returns {Promise} The model output. */ async forward(model_inputs) { - return await seq2seqForward(this, model_inputs, { - encoder_input_name: 'input_features', - }); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -1891,7 +1883,6 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { */ async forward(model_inputs) { return await seq2seqForward(this, model_inputs, { - encoder_input_name: 'pixel_values', add_decoder_pkv: false }) } From f6d1578abd41d6f0496d9fc6063d11caf1bfb29b Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 21:35:59 +0200 Subject: [PATCH 13/22] Extract `validateInputs` helper function from `sessionRun` --- src/models.js | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/src/models.js b/src/models.js index c48b201b6..3b9cec8c7 100644 --- a/src/models.js +++ b/src/models.js @@ -205,19 +205,14 @@ async function constructSession(pretrained_model_name_or_path, fileName, options } /** - * Executes an InferenceSession using the specified inputs. - * NOTE: `inputs` must contain at least the input names of the model. - * - If additional inputs are passed, they will be ignored. - * - If inputs are missing, an error will be thrown. - * - * @param {InferenceSession} session The InferenceSession object to run. - * @param {Object} inputs An object that maps input names to input tensors. - * @returns {Promise} A Promise that resolves to an object that maps output names to output tensors. + * Validate model inputs + * @param {InferenceSession} session The InferenceSession object that will be run. + * @param {Object} inputs The inputs to check. + * @returns {Promise} A Promise that resolves to the checked inputs. + * @throws {Error} If any inputs are missing. * @private */ -async function sessionRun(session, inputs) { - - // First, check that all inputs are provided +async function validateInputs(session, inputs) { // NOTE: Only create a shallow copy const checkedInputs = {}; const missingInputs = []; @@ -242,6 +237,22 @@ async function sessionRun(session, inputs) { console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`); } + return checkedInputs; +} + +/** + * Executes an InferenceSession using the specified inputs. + * NOTE: `inputs` must contain at least the input names of the model. + * - If additional inputs are passed, they will be ignored. + * - If inputs are missing, an error will be thrown. + * + * @param {InferenceSession} session The InferenceSession object to run. + * @param {Object} inputs An object that maps input names to input tensors. + * @returns {Promise} A Promise that resolves to an object that maps output names to output tensors. + * @private + */ +async function sessionRun(session, inputs) { + const checkedInputs = await validateInputs(session, inputs); try { let output = await session.run(checkedInputs); output = replaceTensors(output); @@ -478,7 +489,7 @@ async function encoderForward(self, model_inputs) { if (self.session.inputNames.includes('attention_mask')) { encoderFeeds.attention_mask = model_inputs.attention_mask; } - return await sessionRun(self.session, encoderFeeds);; + return await sessionRun(self.session, encoderFeeds); } From ae34b74b03045a799e6dbb1c788bf5851d58ef12 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Thu, 15 Jun 2023 21:36:36 +0200 Subject: [PATCH 14/22] Move `compare` helper function to separate utility file --- tests/pipelines.test.js | 41 +-------------------------------------- tests/test_utils.js | 43 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 40 deletions(-) create mode 100644 tests/test_utils.js diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index 26d5100f6..5d988ef83 100644 --- a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -1,50 +1,11 @@ import { pipeline, cos_sim } from '../src/transformers.js'; import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js'; - +import { compare } from './test_utils.js'; // Initialise the testing environment init(); - -function compare(val1, val2, tol = 0.1) { - if ( - (val1 !== null && val2 !== null) && - (typeof val1 === 'object' && typeof val2 === 'object') - ) { - // Both are non-null objects - - if (Array.isArray(val1) && Array.isArray(val2)) { - expect(val1).toHaveLength(val2.length); - - for (let i = 0; i < val1.length; ++i) { - compare(val1[i], val2[i], tol); - } - - } else { - expect(Object.keys(val1)).toHaveLength(Object.keys(val2).length); - - for (let key in val1) { - compare(val1[key], val2[key]); - } - } - - } else { - // At least one of them is not an object - // First check that both have the same type - expect(typeof val1).toEqual(typeof val2); - - if (typeof val1 === 'number' && (!Number.isInteger(val1) || !Number.isInteger(val2))) { - // If both are numbers and at least one of them is not an integer - expect(val1).toBeCloseTo(val2, tol); - } else { - // Perform equality test - expect(val1).toEqual(val2); - } - } -} - - // NOTE: // Due to a memory leak in Jest, we cannot have multiple tests for a single model. // This is due to how model construction and destruction occurs, in `beforeAll` and `afterAll`, respectively. diff --git a/tests/test_utils.js b/tests/test_utils.js new file mode 100644 index 000000000..22d71947e --- /dev/null +++ b/tests/test_utils.js @@ -0,0 +1,43 @@ + +/** + * Deep equality test (for arrays and objects) with tolerance for floating point numbers + * @param {any} val1 The first item + * @param {any} val2 The second item + * @param {number} tol Tolerance for floating point numbers + */ +export function compare(val1, val2, tol = 0.1) { + if ( + (val1 !== null && val2 !== null) && + (typeof val1 === 'object' && typeof val2 === 'object') + ) { + // Both are non-null objects + + if (Array.isArray(val1) && Array.isArray(val2)) { + expect(val1).toHaveLength(val2.length); + + for (let i = 0; i < val1.length; ++i) { + compare(val1[i], val2[i], tol); + } + + } else { + expect(Object.keys(val1)).toHaveLength(Object.keys(val2).length); + + for (let key in val1) { + compare(val1[key], val2[key]); + } + } + + } else { + // At least one of them is not an object + // First check that both have the same type + expect(typeof val1).toEqual(typeof val2); + + if (typeof val1 === 'number' && (!Number.isInteger(val1) || !Number.isInteger(val2))) { + // If both are numbers and at least one of them is not an integer + expect(val1).toBeCloseTo(val2, tol); + } else { + // Perform equality test + expect(val1).toEqual(val2); + } + } +} \ No newline at end of file From 63b9b267ee2b6dc2d695a87e9999adcb8d6b2e9e Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 16 Jun 2023 01:44:24 +0200 Subject: [PATCH 15/22] Default `model_type` to null --- src/configs.js | 1 + 1 file changed, 1 insertion(+) diff --git a/src/configs.js b/src/configs.js index 8eb1b48e1..4506d2d9c 100644 --- a/src/configs.js +++ b/src/configs.js @@ -59,6 +59,7 @@ export class PretrainedConfig { * @param {Object} configJSON The JSON of the config. */ constructor(configJSON) { + this.model_type = null; this.is_encoder_decoder = false; Object.assign(this, configJSON); From 6a8175f5d317b53ce9d888f94dfe8c0c53447def Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 16 Jun 2023 01:45:57 +0200 Subject: [PATCH 16/22] Reduce duplication when loading models using `.from_pretrained` --- src/models.js | 526 +++++++++++++++++++++++++------------------------- 1 file changed, 268 insertions(+), 258 deletions(-) diff --git a/src/models.js b/src/models.js index 3b9cec8c7..957c6d22a 100644 --- a/src/models.js +++ b/src/models.js @@ -84,9 +84,8 @@ const { InferenceSession, Tensor: ONNXTensor } = ONNX; class ModelType { }; // Either encoder-only or encoder-decoder (and will be decided by `model.config.is_encoder_decoder`) -class EncoderModelType extends ModelType { }; -class EncoderOnlyModelType extends EncoderModelType { }; -class EncoderDecoderModelType extends EncoderModelType { }; +class EncoderOnlyModelType extends ModelType { }; +class EncoderDecoderModelType extends ModelType { }; class Seq2SeqModelType extends EncoderDecoderModelType { }; class DecoderOnlyModelType extends ModelType { }; ////////////////////////////////////////////////// @@ -95,66 +94,9 @@ class DecoderOnlyModelType extends ModelType { }; ////////////////////////////////////////////////// // Helper functions -/** - * Determines whether `a` is a subclass of `b`. - * @param {any} a The first item - * @param {any} b The second item - * @returns Whether `a` is a subclass of `b`. - * @private - */ -function issubclass(a, b) { - return a === b || a.prototype instanceof b; -} - -/** - * Loads an model from the specified path. - * @param {Object} cls The class of the model. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - * @private - */ -async function load(cls, pretrained_model_name_or_path, options) { - const cModelType = cls.MODEL_TYPE; - if (cls.MODEL_TYPE === null) { - throw new Error("`MODEL_TYPE` not implemented for this model."); - } - if (issubclass(cModelType, DecoderOnlyModelType)) { - return await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - ]); - - } else if (issubclass(cModelType, Seq2SeqModelType)) { - return await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), - ]); - - } else if (issubclass(cModelType, EncoderDecoderModelType)) { - return await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - ]); - - } else if (issubclass(cModelType, EncoderOnlyModelType)) { - return await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'model', options) - ]); - - } else if (issubclass(cModelType, EncoderModelType)) { - let config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options) - let modelName = config.is_encoder_decoder ? 'encoder_model' : 'model'; - let session = await constructSession(pretrained_model_name_or_path, modelName, options); - return [config, session]; - } else { - throw Error(`Unable to determine model type: ${cModelType?.constructor?.name}`); - } -} +// Will be populated later +const MODEL_TYPE_MAPPING = new Map(); +const MODEL_CLASS_MAPPING = new Map(); /** * Helper function to determine which `forward` method to run for a specific model. @@ -163,7 +105,7 @@ async function load(cls, pretrained_model_name_or_path, options) { * @returns {Promise} The model output */ async function forward(self, model_inputs) { - if (issubclass(self.constructor.MODEL_TYPE, DecoderOnlyModelType)) { + if (MODEL_TYPE_MAPPING.get(self.constructor.name) === DecoderOnlyModelType) { return await decoderForward(self, model_inputs); } else { return await encoderForward(self, model_inputs); @@ -507,7 +449,7 @@ async function decoderForward(self, model_inputs) { attention_mask: model_inputs.attention_mask, use_cache_branch: boolTensor(past_key_values !== null) } - self.addPastKeyValues(decoderFeeds, past_key_values) + self.addPastKeyValues(decoderFeeds, past_key_values); let decoderResults = await sessionRun(self.session, decoderFeeds); let logits = decoderResults.logits; @@ -617,7 +559,6 @@ function decoderUpdatebeam(beam, newTokenId) { * @extends Callable */ export class PreTrainedModel extends Callable { - static MODEL_TYPE = EncoderModelType; /** * Creates a new instance of the `PreTrainedModel` class. @@ -634,11 +575,9 @@ export class PreTrainedModel extends Callable { /** * Disposes of all the ONNX sessions that were created during inference. * @returns {Promise} An array of promises, one for each ONNX session that is being disposed. + * @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry */ async dispose() { - // Dispose of all ONNX sessions sessions - // TODO use: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry - let promises = []; for (let key of Object.keys(this)) { let item = this[key]; @@ -650,10 +589,17 @@ export class PreTrainedModel extends Callable { } /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. + * Instantiate one of the model classes of the library from a pretrained model. * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. + * The model class to instantiate is selected based on the `model_type` property of the config object + * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) + * + * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: + * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + * user or organization name, like `dbmdz/bert-base-german-cased`. + * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. + * @param {PretrainedOptions} options Additional options for loading the model. * * @returns {Promise} A new instance of the `PreTrainedModel` class. */ @@ -665,14 +611,50 @@ export class PreTrainedModel extends Callable { local_files_only = false, revision = 'main', } = {}) { - let info = await load(this, pretrained_model_name_or_path, { + + let options = { quantized, progress_callback, config, cache_dir, local_files_only, revision, - }); + } + + let modelType = MODEL_TYPE_MAPPING.get(this.name); + + let info; + if (modelType === DecoderOnlyModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + ]); + + } else if (modelType === Seq2SeqModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'encoder_model', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), + ]); + + } else if (modelType === EncoderDecoderModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'encoder_model', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + ]); + + } else if (modelType === EncoderOnlyModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'model', options) + ]); + + } else { + console.warn('Malformed class definition.', this); + throw Error(`Unable to load model: ${pretrained_model_name_or_path}. Please report this bug at https://github.com/xenova/transformers.js/issues/new/choose.`); + } // @ts-ignore return new this(...info); @@ -1401,7 +1383,6 @@ export class T5Model extends T5PreTrainedModel { * @extends T5PreTrainedModel */ export class T5ForConditionalGeneration extends T5PreTrainedModel { - static MODEL_TYPE = Seq2SeqModelType; /** * Creates a new instance of the `T5ForConditionalGeneration` class. @@ -1487,7 +1468,6 @@ export class MT5Model extends MT5PreTrainedModel { * @extends MT5PreTrainedModel */ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { - static MODEL_TYPE = Seq2SeqModelType; /** * Creates a new instance of the `MT5ForConditionalGeneration` class. @@ -1580,7 +1560,6 @@ export class BartModel extends BartPretrainedModel { * @extends BartPretrainedModel */ export class BartForConditionalGeneration extends BartPretrainedModel { - static MODEL_TYPE = Seq2SeqModelType; /** * Creates a new instance of the `BartForConditionalGeneration` class. @@ -1737,7 +1716,6 @@ export class WhisperModel extends WhisperPreTrainedModel { * @extends WhisperPreTrainedModel */ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { - static MODEL_TYPE = Seq2SeqModelType; /** * Creates a new instance of the `WhisperForConditionalGeneration` class. @@ -1911,14 +1889,28 @@ export class CLIPModel extends CLIPPreTrainedModel { ////////////////////////////////////////////////// // GPT2 models -export class GPT2PreTrainedModel extends PreTrainedModel { } -/** - * GPT2Model is not compatible with `.generate()`, as it doesn't have a language model head. - * @extends GPT2PreTrainedModel - */ +export class GPT2PreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `GPT2PreTrainedModel` class. + * @param {Object} config The configuration of the model. + * @param {any} session The ONNX session containing the model weights. + */ + constructor(config, session) { + super(config, session); + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id + + this.num_heads = this.config.n_head + this.num_layers = this.config.n_layer + this.dim_kv = this.config.n_embd / this.num_heads; + } +} + export class GPT2Model extends GPT2PreTrainedModel { + /** - * + * GPT2Model is not compatible with `.generate()`, as it doesn't have a language model head. * @param {...any} args * @throws {Error} * @returns {Promise} @@ -1935,23 +1927,6 @@ export class GPT2Model extends GPT2PreTrainedModel { * @extends GPT2PreTrainedModel */ export class GPT2LMHeadModel extends GPT2PreTrainedModel { - static MODEL_TYPE = DecoderOnlyModelType; - - /** - * Creates a new instance of the `GPT2LMHeadModel` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. - */ - constructor(config, session) { - super(config, session); - - // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id - - this.num_heads = this.config.n_head - this.num_layers = this.config.n_layer - this.dim_kv = this.config.n_embd / this.num_heads; - } /** * Initializes and returns the beam for text generation task @@ -1996,7 +1971,23 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { // TODO // } ////////////////////////////////////////////////// -export class GPTNeoPreTrainedModel extends PreTrainedModel { } +export class GPTNeoPreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `GPTNeoPreTrainedModel` class. + * @param {Object} config The configuration of the model. + * @param {any} session The ONNX session containing the model weights. + */ + constructor(config, session) { + super(config, session); + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id + + this.num_heads = this.config.num_heads; + this.num_layers = this.config.num_layers; + this.dim_kv = this.config.hidden_size / this.num_heads; + } +} export class GPTNeoModel extends GPTNeoPreTrainedModel { /** * @@ -2012,23 +2003,6 @@ export class GPTNeoModel extends GPTNeoPreTrainedModel { } export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { - static MODEL_TYPE = DecoderOnlyModelType; - - /** - * Creates a new instance of the `GPTNeoForCausalLM` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. - */ - constructor(config, session) { - super(config, session); - - // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id - - this.num_heads = this.config.num_heads; - this.num_layers = this.config.num_layers; - this.dim_kv = this.config.hidden_size / this.num_heads; - } /** * Initializes and returns the beam for text generation task @@ -2071,7 +2045,23 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { ////////////////////////////////////////////////// // CodeGen models -export class CodeGenPreTrainedModel extends PreTrainedModel { } +export class CodeGenPreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `CodeGenPreTrainedModel` class. + * @param {Object} config The model configuration object. + * @param {Object} session The ONNX session object. + */ + constructor(config, session) { + super(config, session); + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id + + this.num_heads = this.config.n_head + this.num_layers = this.config.n_layer + this.dim_kv = this.config.n_embd / this.num_heads; + } +} /** * CodeGenModel is a class representing a code generation model without a language model head. * @@ -2099,23 +2089,6 @@ export class CodeGenModel extends CodeGenPreTrainedModel { * @extends CodeGenPreTrainedModel */ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { - static MODEL_TYPE = DecoderOnlyModelType; - - /** - * Creates a new instance of the `CodeGenForCausalLM` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. - */ - constructor(config, session) { - super(config, session); - - // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id - - this.num_heads = this.config.n_head - this.num_layers = this.config.n_layer - this.dim_kv = this.config.n_embd / this.num_heads; - } /** * Initializes and returns the beam for text generation task @@ -2275,7 +2248,6 @@ export class MarianModel extends MarianPreTrainedModel { } export class MarianMTModel extends MarianPreTrainedModel { - static MODEL_TYPE = Seq2SeqModelType; /** * Creates a new instance of the `MarianMTModel` class. @@ -2355,7 +2327,6 @@ export class M2M100Model extends M2M100PreTrainedModel { } export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { - static MODEL_TYPE = Seq2SeqModelType; /** * Creates a new instance of the `M2M100ForConditionalGeneration` class. @@ -2429,8 +2400,9 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { export class PretrainedMixin { /** * Mapping from model type to model class. + * @type {Map[]} */ - static MODEL_CLASS_MAPPING = Object.create(null); + static MODEL_CLASS_MAPPINGS = null; /** * Whether to attempt to instantiate the base class (`PretrainedModel`) if @@ -2438,26 +2410,8 @@ export class PretrainedMixin { */ static BASE_IF_FAIL = false; - /** - * The type of model. We set this to null since we require each `AutoModel` to override this. - */ - static MODEL_TYPE = null; - /** - * Instantiate one of the model classes of the library from a pretrained model. - * - * The model class to instantiate is selected based on the `model_type` property of the config object - * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) - * - * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: - * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a - * user or organization name, like `dbmdz/bert-base-german-cased`. - * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. - * @param {PretrainedOptions} options Additional options for loading the model. - * - * @returns {Promise} A new instance of the `PreTrainedModel` class. - */ + /** @type {PreTrainedModel.from_pretrained} */ static async from_pretrained(pretrained_model_name_or_path, { quantized = true, progress_callback = null, @@ -2466,25 +2420,158 @@ export class PretrainedMixin { local_files_only = false, revision = 'main', } = {}) { - let info = await load(this, pretrained_model_name_or_path, { + + let options = { quantized, progress_callback, config, cache_dir, local_files_only, revision, - }); + } + config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); - let cls = this.MODEL_CLASS_MAPPING[info[0].model_type]; - if (!cls) { - if (this.BASE_IF_FAIL) { - console.warn(`Unknown model class "${info[0].model_type}", attempting to construct from base class.`); - cls = PreTrainedModel; - } else { - throw Error(`Unsupported model type: ${info[0].model_type}`) + if (!this.MODEL_CLASS_MAPPINGS) { + throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name); + } + + let modelClass; + for (let MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) { + modelClass = MODEL_CLASS_MAPPING.get(config.model_type); + if (!modelClass) { + continue; // Item not found in this mapping } + + return await modelClass.from_pretrained(pretrained_model_name_or_path, options); + } + + if (this.BASE_IF_FAIL) { + console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`); + return await PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options); + } else { + throw Error(`Unsupported model type: ${config.model_type}`) } - return new cls(...info); + } +} + +const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ + ['bert', BertModel], + ['albert', AlbertModel], + ['distilbert', DistilBertModel], + ['roberta', RobertaModel], + ['clip', CLIPModel], + ['mobilebert', MobileBertModel], + ['squeezebert', SqueezeBertModel], + + ['sam', SamModel], // TODO change to encoder-decoder when model is split correctly +]); + +const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ + ['t5', T5Model], + ['mt5', MT5Model], + ['bart', BartModel], + ['marian', MarianModel], + ['whisper', WhisperModel], + ['m2m_100', M2M100Model], +]); + + +const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ + ['gpt2', GPT2Model], + ['gpt_neo', GPTNeoModel], + ['codegen', CodeGenModel], +]); + +const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ + ['bert', BertForSequenceClassification], + ['albert', AlbertForSequenceClassification], + ['distilbert', DistilBertForSequenceClassification], + ['roberta', RobertaForSequenceClassification], + ['bart', BartForSequenceClassification], + ['mobilebert', MobileBertForSequenceClassification], + ['squeezebert', SqueezeBertForSequenceClassification], +]); + +const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ + ['bert', BertForTokenClassification], + ['distilbert', DistilBertForTokenClassification], +]); + +const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ + ['t5', T5ForConditionalGeneration], + ['mt5', MT5ForConditionalGeneration], + ['bart', BartForConditionalGeneration], + ['whisper', WhisperForConditionalGeneration], + ['marian', MarianMTModel], + ['m2m_100', M2M100ForConditionalGeneration], +]); + +const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ + ['gpt2', GPT2LMHeadModel], + ['gpt_neo', GPTNeoForCausalLM], + ['codegen', CodeGenForCausalLM], +]); + +const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ + ['bert', BertForMaskedLM], + ['albert', AlbertForMaskedLM], + ['distilbert', DistilBertForMaskedLM], + ['roberta', RobertaForMaskedLM], + ['mobilebert', MobileBertForMaskedLM], + ['squeezebert', SqueezeBertForMaskedLM], +]); + +const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ + ['bert', BertForQuestionAnswering], + ['albert', AlbertForQuestionAnswering], + ['distilbert', DistilBertForQuestionAnswering], + ['roberta', RobertaForQuestionAnswering], + ['mobilebert', MobileBertForQuestionAnswering], + ['squeezebert', SqueezeBertForQuestionAnswering], +]); + +const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ + ['vision-encoder-decoder', VisionEncoderDecoderModel], +]); + +const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([ + ['vit', ViTForImageClassification], +]); + +const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([ + ['detr', DetrForObjectDetection], +]); + +const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([ + ['detr', DetrForSegmentation], +]); + +const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([ + ['sam', SamModel], +]); + +const MODEL_CLASS_TYPE_MAPPING = [ + [MODEL_MAPPING_NAMES_ENCODER_ONLY, EncoderOnlyModelType], + [MODEL_MAPPING_NAMES_ENCODER_DECODER, EncoderDecoderModelType], + [MODEL_MAPPING_NAMES_DECODER_ONLY, DecoderOnlyModelType], + [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, Seq2SeqModelType], + [MODEL_WITH_LM_HEAD_MAPPING_NAMES, DecoderOnlyModelType], + [MODEL_FOR_MASKED_LM_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, EncoderDecoderModelType], + [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, EncoderOnlyModelType], +]; + +for (let [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { + // @ts-ignore + for (let [name, model] of mappings.entries()) { + MODEL_TYPE_MAPPING.set(model.name, type); + MODEL_CLASS_MAPPING.set(model.name, name); } } @@ -2496,27 +2583,8 @@ export class PretrainedMixin { * let model = await AutoModel.from_pretrained('bert-base-uncased'); */ export class AutoModel extends PretrainedMixin { - static MODEL_TYPE = EncoderModelType; // Assume to be an encoder-only or encoder-decoder model. + static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY]; static BASE_IF_FAIL = true; - static MODEL_CLASS_MAPPING = { - 'bert': BertModel, - 'albert': AlbertModel, - 'distilbert': DistilBertModel, - 't5': T5Model, - 'mt5': MT5Model, - 'gpt2': GPT2Model, - 'gpt_neo': GPTNeoModel, - 'codegen': CodeGenModel, - 'bart': BartModel, - 'roberta': RobertaModel, - 'whisper': WhisperModel, - 'clip': CLIPModel, - 'mobilebert': MobileBertModel, - 'squeezebert': SqueezeBertModel, - 'marian': MarianModel, - 'm2m_100': M2M100Model, - 'sam': SamModel, - } } /** @@ -2527,16 +2595,7 @@ export class AutoModel extends PretrainedMixin { * let model = await AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english'); */ export class AutoModelForSequenceClassification extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'bert': BertForSequenceClassification, - 'albert': AlbertForSequenceClassification, - 'distilbert': DistilBertForSequenceClassification, - 'roberta': RobertaForSequenceClassification, - 'bart': BartForSequenceClassification, - 'mobilebert': MobileBertForSequenceClassification, - 'squeezebert': SqueezeBertForSequenceClassification, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]; } /** @@ -2547,14 +2606,9 @@ export class AutoModelForSequenceClassification extends PretrainedMixin { * let model = await AutoModelForTokenClassification.from_pretrained('Davlan/distilbert-base-multilingual-cased-ner-hrl'); */ export class AutoModelForTokenClassification extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'bert': BertForTokenClassification, - 'distilbert': DistilBertForTokenClassification, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]; } - /** * Helper class which is used to instantiate pretrained sequence-to-sequence models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. @@ -2563,15 +2617,7 @@ export class AutoModelForTokenClassification extends PretrainedMixin { * let model = await AutoModelForSeq2SeqLM.from_pretrained('t5-small'); */ export class AutoModelForSeq2SeqLM extends PretrainedMixin { - static MODEL_TYPE = Seq2SeqModelType; - static MODEL_CLASS_MAPPING = { - 't5': T5ForConditionalGeneration, - 'mt5': MT5ForConditionalGeneration, - 'bart': BartForConditionalGeneration, - 'whisper': WhisperForConditionalGeneration, - 'marian': MarianMTModel, - 'm2m_100': M2M100ForConditionalGeneration, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES]; } /** @@ -2582,12 +2628,7 @@ export class AutoModelForSeq2SeqLM extends PretrainedMixin { * let model = await AutoModelForCausalLM.from_pretrained('gpt2'); */ export class AutoModelForCausalLM extends PretrainedMixin { - static MODEL_TYPE = DecoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'gpt2': GPT2LMHeadModel, - 'gpt_neo': GPTNeoForCausalLM, - 'codegen': CodeGenForCausalLM, - } + static MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]; } /** @@ -2598,15 +2639,7 @@ export class AutoModelForCausalLM extends PretrainedMixin { * let model = await AutoModelForMaskedLM.from_pretrained('bert-base-uncased'); */ export class AutoModelForMaskedLM extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'bert': BertForMaskedLM, - 'albert': AlbertForMaskedLM, - 'distilbert': DistilBertForMaskedLM, - 'roberta': RobertaForMaskedLM, - 'mobilebert': MobileBertForMaskedLM, - 'squeezebert': SqueezeBertForMaskedLM, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]; } /** @@ -2617,15 +2650,7 @@ export class AutoModelForMaskedLM extends PretrainedMixin { * let model = await AutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad'); */ export class AutoModelForQuestionAnswering extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'bert': BertForQuestionAnswering, - 'albert': AlbertForQuestionAnswering, - 'distilbert': DistilBertForQuestionAnswering, - 'roberta': RobertaForQuestionAnswering, - 'mobilebert': MobileBertForQuestionAnswering, - 'squeezebert': SqueezeBertForQuestionAnswering, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]; } /** @@ -2636,10 +2661,7 @@ export class AutoModelForQuestionAnswering extends PretrainedMixin { * let model = await AutoModelForVision2Seq.from_pretrained('nlpconnect/vit-gpt2-image-captioning'); */ export class AutoModelForVision2Seq extends PretrainedMixin { - static MODEL_TYPE = EncoderDecoderModelType; - static MODEL_CLASS_MAPPING = { - 'vision-encoder-decoder': VisionEncoderDecoderModel - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]; } /** @@ -2650,10 +2672,7 @@ export class AutoModelForVision2Seq extends PretrainedMixin { * let model = await AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224'); */ export class AutoModelForImageClassification extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'vit': ViTForImageClassification, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]; } /** @@ -2664,10 +2683,7 @@ export class AutoModelForImageClassification extends PretrainedMixin { * let model = await AutoModelForImageSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic'); */ export class AutoModelForImageSegmentation extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'detr': DetrForSegmentation, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]; } /** @@ -2678,10 +2694,7 @@ export class AutoModelForImageSegmentation extends PretrainedMixin { * let model = await AutoModelForObjectDetection.from_pretrained('facebook/detr-resnet-50'); */ export class AutoModelForObjectDetection extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'detr': DetrForObjectDetection, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]; } /** @@ -2692,10 +2705,7 @@ export class AutoModelForObjectDetection extends PretrainedMixin { * let model = await AutoModelForMaskGeneration.from_pretrained('Xenova/sam-vit-base'); */ export class AutoModelForMaskGeneration extends PretrainedMixin { - static MODEL_TYPE = EncoderOnlyModelType; - static MODEL_CLASS_MAPPING = { - 'sam': SamModel, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES]; } ////////////////////////////////////////////////// From ee098bad7e1bc7176e0299f3668a0e75def2d046 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 16 Jun 2023 01:46:29 +0200 Subject: [PATCH 17/22] Add unit tests for loading models using `.from_pretrained()` --- tests/models.test.js | 90 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 tests/models.test.js diff --git a/tests/models.test.js b/tests/models.test.js new file mode 100644 index 000000000..b012581de --- /dev/null +++ b/tests/models.test.js @@ -0,0 +1,90 @@ +/* + * Test that models loaded outside of the `pipeline` function work correctly (e.g., `AutoModel.from_pretrained(...)`); + */ + +import { + AutoTokenizer, + AutoModel, + + BertModel, + GPT2Model, + T5Model, + + BertTokenizer, + GPT2Tokenizer, + T5Tokenizer, + +} from '../src/transformers.js'; + +import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js'; + +import { compare } from './test_utils.js'; + +// Initialise the testing environment +init(); + + +describe('Loading models', () => { + + // List all models which will be tested + const models_to_test = [ + // [name, modelClass, tokenizerClass] + ['bert-base-uncased', BertModel, BertTokenizer], // Encoder-only + ['gpt2', GPT2Model, GPT2Tokenizer], // Decoder-only + ['t5-small', T5Model, T5Tokenizer], // Encoder-decoder + ]; + + let texts = [ + 'Once upon a time', + 'I like to eat apples', + ]; + + for (let [name, modelClass, tokenizerClass] of models_to_test) { + + // Test that both the auto model and the specific model work + let tokenizers = [AutoTokenizer, tokenizerClass]; + let models = [AutoModel, modelClass]; + + for (let i = 0; i < tokenizers.length; ++i) { + const tokenizerClassToTest = tokenizers[i]; + const modelClassToTest = models[i]; + + it(`${name} (${modelClassToTest.name})`, async () => { + const model_id = m(name); + + // Load model and tokenizer + let tokenizer = await tokenizerClassToTest.from_pretrained(model_id); + let model = await modelClassToTest.from_pretrained(model_id); + + let tests = [ + texts[0], // single + texts, // batched + ] + for (let test of tests) { + let encodings = await tokenizer(test, { truncation: true, padding: true }); + let output = await model(encodings); + + if (output.logits) { + // Ensure correct shapes + let expected_shape = [...encodings.input_ids.dims, model.config.vocab_size]; + let actual_shape = output.logits.dims; + compare(expected_shape, actual_shape); + } else if (output.last_hidden_state) { + let expected_shape = [...encodings.input_ids.dims, model.config.d_model]; + let actual_shape = output.last_hidden_state.dims; + compare(expected_shape, actual_shape); + } else { + console.warn('Unexpected output', output); + throw new Error('Unexpected output'); + } + + } + + await model.dispose(); + + }, MAX_TEST_EXECUTION_TIME); + + } + } + +}); \ No newline at end of file From e6f60bff2cdd71905cd18e68de5128c3f45abfe8 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 17 Jun 2023 01:48:49 +0200 Subject: [PATCH 18/22] Compute attention mask for decoder if not given --- src/models.js | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/models.js b/src/models.js index 957c6d22a..3a40f59c1 100644 --- a/src/models.js +++ b/src/models.js @@ -446,9 +446,15 @@ async function decoderForward(self, model_inputs) { let past_key_values = model_inputs.past_key_values; let decoderFeeds = { input_ids: model_inputs.input_ids, - attention_mask: model_inputs.attention_mask, use_cache_branch: boolTensor(past_key_values !== null) } + + if (decoderFeeds.attention_mask) { + decoderFeeds.attention_mask = model_inputs.attention_mask; + } else { + decoderFeeds.attention_mask = prepareAttentionMask(self, model_inputs.input_ids); + } + self.addPastKeyValues(decoderFeeds, past_key_values); let decoderResults = await sessionRun(self.session, decoderFeeds); From ec6415d782d310e9c2c501149fcdb5abd0f63eb9 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 17 Jun 2023 01:49:49 +0200 Subject: [PATCH 19/22] Improve decoder attention computation --- src/models.js | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/models.js b/src/models.js index 3a40f59c1..4cafcb7ae 100644 --- a/src/models.js +++ b/src/models.js @@ -446,15 +446,10 @@ async function decoderForward(self, model_inputs) { let past_key_values = model_inputs.past_key_values; let decoderFeeds = { input_ids: model_inputs.input_ids, + attention_mask: model_inputs.attention_mask ?? prepareAttentionMask(self, model_inputs.input_ids), use_cache_branch: boolTensor(past_key_values !== null) } - if (decoderFeeds.attention_mask) { - decoderFeeds.attention_mask = model_inputs.attention_mask; - } else { - decoderFeeds.attention_mask = prepareAttentionMask(self, model_inputs.input_ids); - } - self.addPastKeyValues(decoderFeeds, past_key_values); let decoderResults = await sessionRun(self.session, decoderFeeds); From b2be59864c9f88cf64f982142af2680ac2f1d7e5 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 17 Jun 2023 01:50:21 +0200 Subject: [PATCH 20/22] Implement `flatten` and `view` tensor ops --- src/utils/tensor.js | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index af1f9fa1b..b90f8ee7f 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -440,6 +440,45 @@ export class Tensor extends ONNXTensor { this.dims = calc_unsqueeze_dims(this.dims, dim); return this; } + + flatten_(start_dim = 0, end_dim = -1) { + // TODO validate inputs + end_dim = (end_dim + this.dims.length) % this.dims.length; + + let dimsToKeepBefore = this.dims.slice(0, start_dim); + let dimsToFlatten = this.dims.slice(start_dim, end_dim + 1); + let dimsToKeepAfter = this.dims.slice(end_dim + 1); + + this.dims = [...dimsToKeepBefore, dimsToFlatten.reduce((a, b) => a * b, 1), ...dimsToKeepAfter] + return this; + } + + flatten(start_dim = 0, end_dim = -1) { + return this.clone().flatten_(start_dim, end_dim); + } + + view(...dims) { + // TODO: validate dims + let inferredIndex = -1; + for (let i = 0; i < dims.length; ++i) { + if (dims[i] === -1) { + if (inferredIndex !== -1) { + throw new Error("Only one dimension can be inferred"); + } + inferredIndex = i; + } + } + + if (inferredIndex !== -1) { + // Some dimension must be inferred + const productOther = dims.reduce((product, curr, index) => { + return index !== inferredIndex ? product * curr : product + }, 1); + + dims[inferredIndex] = this.data.length / productOther; + } + return new Tensor(this.type, this.data, dims); // NOTE: uses same underlying storage + } } /** From 6387f3e25f527e57da41f5bbe6fc3877bfdc82ed Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 19 Jun 2023 18:52:03 +0200 Subject: [PATCH 21/22] Add documentation for new tensor ops --- src/utils/tensor.js | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index b90f8ee7f..ff6983e97 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -441,6 +441,9 @@ export class Tensor extends ONNXTensor { return this; } + /** + * In-place version of @see {@link Tensor.flatten} + */ flatten_(start_dim = 0, end_dim = -1) { // TODO validate inputs end_dim = (end_dim + this.dims.length) % this.dims.length; @@ -453,10 +456,23 @@ export class Tensor extends ONNXTensor { return this; } + /** + * Flattens input by reshaping it into a one-dimensional tensor. + * If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` + * and ending with `end_dim` are flattened. The order of elements in input is unchanged. + * @param {*} start_dim the first dim to flatten + * @param {*} end_dim the last dim to flatten + * @returns The flattened tensor. + */ flatten(start_dim = 0, end_dim = -1) { return this.clone().flatten_(start_dim, end_dim); } + /** + * Returns a new tensor with the same data as the `self` tensor but of a different `shape`. + * @param {...number} dims the desired size + * @returns {Tensor} The tensor with the same data but different shape + */ view(...dims) { // TODO: validate dims let inferredIndex = -1; From 66efc091b7091203c62a5ff4787a83ed0149399c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 19 Jun 2023 18:54:15 +0200 Subject: [PATCH 22/22] Fix `flatten` input types --- src/utils/tensor.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils/tensor.js b/src/utils/tensor.js index ff6983e97..d452287a4 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -460,8 +460,8 @@ export class Tensor extends ONNXTensor { * Flattens input by reshaping it into a one-dimensional tensor. * If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` * and ending with `end_dim` are flattened. The order of elements in input is unchanged. - * @param {*} start_dim the first dim to flatten - * @param {*} end_dim the last dim to flatten + * @param {number} start_dim the first dim to flatten + * @param {number} end_dim the last dim to flatten * @returns The flattened tensor. */ flatten(start_dim = 0, end_dim = -1) {