diff --git a/src/configs.js b/src/configs.js index 8eb1b48e1..4506d2d9c 100644 --- a/src/configs.js +++ b/src/configs.js @@ -59,6 +59,7 @@ export class PretrainedConfig { * @param {Object} configJSON The JSON of the config. */ constructor(configJSON) { + this.model_type = null; this.is_encoder_decoder = false; Object.assign(this, configJSON); diff --git a/src/models.js b/src/models.js index 05410e5d3..4cafcb7ae 100644 --- a/src/models.js +++ b/src/models.js @@ -78,14 +78,47 @@ const { InferenceSession, Tensor: ONNXTensor } = ONNX; * @typedef {import('./utils/hub.js').PretrainedOptions} PretrainedOptions */ + +////////////////////////////////////////////////// +// Model types: used internally +class ModelType { }; + +// Either encoder-only or encoder-decoder (and will be decided by `model.config.is_encoder_decoder`) +class EncoderOnlyModelType extends ModelType { }; +class EncoderDecoderModelType extends ModelType { }; +class Seq2SeqModelType extends EncoderDecoderModelType { }; +class DecoderOnlyModelType extends ModelType { }; +////////////////////////////////////////////////// + + ////////////////////////////////////////////////// // Helper functions + +// Will be populated later +const MODEL_TYPE_MAPPING = new Map(); +const MODEL_CLASS_MAPPING = new Map(); + +/** + * Helper function to determine which `forward` method to run for a specific model. + * @param {Object} self The calling object + * @param {Object} model_inputs The inputs to be sent to the model + * @returns {Promise} The model output + */ +async function forward(self, model_inputs) { + if (MODEL_TYPE_MAPPING.get(self.constructor.name) === DecoderOnlyModelType) { + return await decoderForward(self, model_inputs); + } else { + return await encoderForward(self, model_inputs); + } +} + /** * Constructs an InferenceSession using a model file located at the specified path. * @param {string} pretrained_model_name_or_path The path to the directory containing the model file. 
* @param {string} fileName The name of the model file. * @param {PretrainedOptions} options Additional options for loading the model. * @returns {Promise} A Promise that resolves to an InferenceSession object. + * @private */ async function constructSession(pretrained_model_name_or_path, fileName, options) { // TODO add option for user to force specify their desired execution provider @@ -114,18 +147,14 @@ async function constructSession(pretrained_model_name_or_path, fileName, options } /** - * Executes an InferenceSession using the specified inputs. - * NOTE: `inputs` must contain at least the input names of the model. - * - If additional inputs are passed, they will be ignored. - * - If inputs are missing, an error will be thrown. - * - * @param {InferenceSession} session The InferenceSession object to run. - * @param {Object} inputs An object that maps input names to input tensors. - * @returns {Promise} A Promise that resolves to an object that maps output names to output tensors. + * Validate model inputs + * @param {InferenceSession} session The InferenceSession object that will be run. + * @param {Object} inputs The inputs to check. + * @returns {Promise} A Promise that resolves to the checked inputs. + * @throws {Error} If any inputs are missing. + * @private */ -async function sessionRun(session, inputs) { - - // First, check that all inputs are provided +async function validateInputs(session, inputs) { // NOTE: Only create a shallow copy const checkedInputs = {}; const missingInputs = []; @@ -150,6 +179,22 @@ async function sessionRun(session, inputs) { console.warn(`WARNING: Too many inputs were provided (${numInputsProvided} > ${numInputsNeeded}). The following inputs will be ignored: "${ignored.join(', ')}".`); } + return checkedInputs; +} + +/** + * Executes an InferenceSession using the specified inputs. + * NOTE: `inputs` must contain at least the input names of the model. + * - If additional inputs are passed, they will be ignored. 
+ * - If inputs are missing, an error will be thrown. + * + * @param {InferenceSession} session The InferenceSession object to run. + * @param {Object} inputs An object that maps input names to input tensors. + * @returns {Promise} A Promise that resolves to an object that maps output names to output tensors. + * @private + */ +async function sessionRun(session, inputs) { + const checkedInputs = await validateInputs(session, inputs); try { let output = await session.run(checkedInputs); output = replaceTensors(output); @@ -166,10 +211,9 @@ async function sessionRun(session, inputs) { * Replaces ONNX Tensor objects with custom Tensor objects to support additional functions. * @param {Object} obj The object to replace tensor objects in. * @returns {Object} The object with tensor objects replaced by custom Tensor objects. + * @private */ function replaceTensors(obj) { - // Convert ONNX Tensors with our custom Tensor class - // to support additional functions for (let prop in obj) { if (obj[prop] instanceof ONNXTensor) { obj[prop] = new Tensor(obj[prop]); @@ -219,8 +263,9 @@ function toI64Tensor(items) { * @param {Object} self The calling object instance. * @param {Tensor} tokens The input tokens. * @returns {Tensor} The attention mask tensor. + * @private */ -function _prepare_attention_mask(self, tokens) { +function prepareAttentionMask(self, tokens) { // Prepare attention mask let pad_token_id = self.config.pad_token_id ?? null; @@ -251,132 +296,46 @@ function _prepare_attention_mask(self, tokens) { * Creates a boolean tensor with a single value. * @param {boolean} value The value of the tensor. * @returns {Tensor} The boolean tensor. + * @private */ function boolTensor(value) { - // Create boolean tensor return new Tensor('bool', [value], [1]); } - -/** - * Loads a model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. 
- * @returns {Promise} A promise that resolves with information about the loaded model. - */ -async function loadAutoModel(pretrained_model_name_or_path, options) { - let config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options) - - let modelName = config.is_encoder_decoder ? 'encoder_model' : 'model'; - - let session = await constructSession(pretrained_model_name_or_path, modelName, options); - - return [config, session]; -} - -/** - * Loads a model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - */ -async function loadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'model', options) - ]); - return info; -} - -/** - * Loads a sequence-to-sequence model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - */ -async function seq2seqLoadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), - ]); - return info; -} - -/** - * Loads an encoder-decoder model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. 
- * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - */ -async function encoderDecoderLoadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'encoder_model', options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - ]) - return info; -} - - -/** - * Loads a decoder model from the specified path. - * @param {string} pretrained_model_name_or_path The path to the model directory. - * @param {PretrainedOptions} options Additional options for loading the model. - * @returns {Promise} A promise that resolves with information about the loaded model. - */ -async function decoderLoadModel(pretrained_model_name_or_path, options) { - let info = await Promise.all([ - AutoConfig.from_pretrained(pretrained_model_name_or_path, options), - constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), - ]) - return info; -} - // JS doesn't support mixins, so we define some reused functions here, and allow "this" to be passed in /** - * Perform forward pass on the seq2seq model. + * Perform forward pass on the seq2seq model (both encoder and decoder). * @param {Object} self The seq2seq model object. * @param {Object} model_inputs The input object for the model containing encoder and decoder inputs. * @param {Object} options The options - * @param {string} [options.encoder_input_name='input_ids'] The name of the input tensor for the encoder. * @param {boolean} [options.add_decoder_pkv=true] Flag to add the decoder past key values. * @returns {Promise} Promise that resolves with the output of the seq2seq model. 
+ * @private */ -async function seq2seq_forward(self, model_inputs, { - encoder_input_name = 'input_ids', +async function seq2seqForward(self, model_inputs, { add_decoder_pkv = true } = {}) { - let encoderOutputs = model_inputs.encoder_outputs; - let pastKeyValues = model_inputs.past_key_values; + let { encoder_outputs, past_key_values } = model_inputs; - if (encoderOutputs === null) { - const encoderFeeds = { - [encoder_input_name]: model_inputs[encoder_input_name], - } - - if (self.session.inputNames.includes('attention_mask')) { - encoderFeeds.attention_mask = model_inputs.attention_mask - } - const encoderResults = await sessionRun(self.session, encoderFeeds); - encoderOutputs = encoderResults.last_hidden_state; + if (!encoder_outputs) { + // Encoder outputs are not given, so we must compute them. + encoder_outputs = (await encoderForward(self, model_inputs)).last_hidden_state; } let decoderFeeds = { input_ids: model_inputs.decoder_input_ids, - encoder_hidden_states: encoderOutputs, - use_cache_branch: boolTensor(pastKeyValues !== null) + encoder_hidden_states: encoder_outputs, + use_cache_branch: boolTensor(past_key_values !== null) }; if (self.decoder_merged_session.inputNames.includes('encoder_attention_mask')) { decoderFeeds.encoder_attention_mask = model_inputs.attention_mask } - self.addPastKeyValues(decoderFeeds, pastKeyValues, add_decoder_pkv); + self.addPastKeyValues(decoderFeeds, past_key_values, add_decoder_pkv); const decoderResults = await sessionRun(self.decoder_merged_session, decoderFeeds); let logits = decoderResults.logits; - pastKeyValues = self.getPastKeyValues(decoderResults, pastKeyValues); - return new Seq2SeqLMOutput(logits, pastKeyValues, encoderOutputs); + past_key_values = self.getPastKeyValues(decoderResults, past_key_values); + return new Seq2SeqLMOutput({ logits, past_key_values, encoder_outputs }); } /** @@ -386,6 +345,7 @@ async function seq2seq_forward(self, model_inputs, { * @param {number} numOutputTokens The maximum number of 
output tokens for the model. * @param {boolean} [requires_attention_mask=true] Flag to indicate if the model requires an attention mask. * @returns {Object[]} Array of beam search objects. + * @private */ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attention_mask = true) { let beams = []; @@ -416,7 +376,7 @@ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attent } if (requires_attention_mask) { - start.attention_mask = _prepare_attention_mask(self, tokens); + start.attention_mask = prepareAttentionMask(self, tokens); } beams.push(start); @@ -432,6 +392,7 @@ function seq2seqStartBeams(self, inputTokenIds, numOutputTokens, requires_attent * @param {Object} options options * @param {string} [options.input_name='input_ids'] The name of the input tensor for the encoder. * @returns {Promise} Promise that resolves with the output of the seq2seq model for the given beam. + * @private */ async function seq2seqRunBeam(self, beam, { input_name = 'input_ids', @@ -459,19 +420,37 @@ async function seq2seqRunBeam(self, beam, { } /** - * Forward pass of the text generation model. - * @param {Object} self The text generation model object. + * Forward pass of an encoder model. + * @param {Object} self The encoder model. + * @param {Object} model_inputs The input data to be used for the forward pass. + * @returns {Promise} Promise that resolves with an object containing the model's outputs. + * @private + */ +async function encoderForward(self, model_inputs) { + const encoderFeeds = { ...model_inputs }; // Shallow copy + if (self.session.inputNames.includes('attention_mask')) { + encoderFeeds.attention_mask = model_inputs.attention_mask; + } + return await sessionRun(self.session, encoderFeeds); +} + + +/** + * Forward pass of a decoder model. + * @param {Object} self The decoder model. * @param {Object} model_inputs The input data to be used for the forward pass. 
* @returns {Promise} Promise that resolves with an object containing the logits and past key values. + * @private */ -async function textgen_forward(self, model_inputs) { +async function decoderForward(self, model_inputs) { let past_key_values = model_inputs.past_key_values; let decoderFeeds = { input_ids: model_inputs.input_ids, - attention_mask: model_inputs.attention_mask, + attention_mask: model_inputs.attention_mask ?? prepareAttentionMask(self, model_inputs.input_ids), use_cache_branch: boolTensor(past_key_values !== null) } - self.addPastKeyValues(decoderFeeds, past_key_values) + + self.addPastKeyValues(decoderFeeds, past_key_values); let decoderResults = await sessionRun(self.session, decoderFeeds); let logits = decoderResults.logits; @@ -487,8 +466,9 @@ async function textgen_forward(self, model_inputs) { * @param {number} numOutputTokens The maximum number of tokens to generate for each beam. * @param {Tensor} [inputs_attention_mask] The attention mask tensor for the input token IDs. * @returns {Object[]} An array of beams initialized with the given inputs and parameters. + * @private */ -function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attention_mask) { +function decoderStartBeams(self, inputTokenIds, numOutputTokens, inputs_attention_mask) { let beams = []; let beamId = 0; @@ -504,7 +484,7 @@ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio attn_mask.dims = [1, ...attn_mask.dims] } else { - attn_mask = _prepare_attention_mask(self, tokens) + attn_mask = prepareAttentionMask(self, tokens) } let start = { @@ -529,7 +509,7 @@ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio /** * Runs a single step of the text generation process for a given beam. * - * @param {Object} self The textgen object. + * @param {Object} self The decoder object. * @param {Object} beam The beam to run. * @param {Tensor} beam.input The input tensor. 
* @param {Tensor} beam.model_input_ids The input ids to the model. @@ -537,8 +517,9 @@ function textgenStartBeams(self, inputTokenIds, numOutputTokens, inputs_attentio * @param {Object} beam.past_key_values The past key values. * @param {number[]} beam.output_token_ids The output token ids. * @returns {Promise} The output of the generation step. + * @private */ -async function textgenRunBeam(self, beam) { +async function decoderRunBeam(self, beam) { let attnMaskData = new BigInt64Array(beam.input.data.length + beam.output_token_ids.length).fill(1n) // 1. Prepare @@ -565,8 +546,9 @@ async function textgenRunBeam(self, beam) { * Update a beam with a new token ID. * @param {Object} beam The beam to update. * @param {number} newTokenId The new token ID to add to the beam's output. + * @private */ -function textgenUpdatebeam(beam, newTokenId) { +function decoderUpdatebeam(beam, newTokenId) { beam.output_token_ids = [...beam.output_token_ids, newTokenId]; beam.model_input_ids = new Tensor('int64', [BigInt(newTokenId)], [1, 1]); } @@ -578,6 +560,7 @@ function textgenUpdatebeam(beam, newTokenId) { * @extends Callable */ export class PreTrainedModel extends Callable { + /** * Creates a new instance of the `PreTrainedModel` class. * @param {Object} config The model configuration. @@ -593,11 +576,9 @@ export class PreTrainedModel extends Callable { /** * Disposes of all the ONNX sessions that were created during inference. * @returns {Promise} An array of promises, one for each ONNX session that is being disposed. 
+ * @todo Use https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry */ async dispose() { - // Dispose of all ONNX sessions sessions - // TODO use: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/FinalizationRegistry - let promises = []; for (let key of Object.keys(this)) { let item = this[key]; @@ -609,10 +590,17 @@ export class PreTrainedModel extends Callable { } /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. + * Instantiate one of the model classes of the library from a pretrained model. + * + * The model class to instantiate is selected based on the `model_type` property of the config object + * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. + * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: + * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + * user or organization name, like `dbmdz/bert-base-german-cased`. + * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. + * @param {PretrainedOptions} options Additional options for loading the model. * * @returns {Promise} A new instance of the `PreTrainedModel` class. 
*/ @@ -624,14 +612,50 @@ export class PreTrainedModel extends Callable { local_files_only = false, revision = 'main', } = {}) { - let info = await loadAutoModel(pretrained_model_name_or_path, { + + let options = { quantized, progress_callback, config, cache_dir, local_files_only, revision, - }); + } + + let modelType = MODEL_TYPE_MAPPING.get(this.name); + + let info; + if (modelType === DecoderOnlyModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + ]); + + } else if (modelType === Seq2SeqModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'encoder_model', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options), + ]); + + } else if (modelType === EncoderDecoderModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'encoder_model', options), + constructSession(pretrained_model_name_or_path, 'decoder_model_merged', options), + ]); + + } else if (modelType === EncoderOnlyModelType) { + info = await Promise.all([ + AutoConfig.from_pretrained(pretrained_model_name_or_path, options), + constructSession(pretrained_model_name_or_path, 'model', options) + ]); + + } else { + console.warn('Malformed class definition.', this); + throw Error(`Unable to load model: ${pretrained_model_name_or_path}. 
Please report this bug at https://github.com/xenova/transformers.js/issues/new/choose.`); + } // @ts-ignore return new this(...info); @@ -643,18 +667,18 @@ export class PreTrainedModel extends Callable { * @returns {Promise} Object containing output tensors */ async _call(model_inputs) { - return await sessionRun(this.session, model_inputs); + return await this.forward(model_inputs); } /** - * Forward method should be implemented in subclasses. - * @abstract + * Forward method for a pretrained model. If not overridden by a subclass, the correct forward method + * will be chosen based on the model type. * @param {Object} model_inputs The input data to the model in the format specified in the ONNX model. * @returns {Promise} The output data from the model in the format specified in the ONNX model. * @throws {Error} This method must be implemented in subclasses. */ async forward(model_inputs) { - throw Error("forward should be implemented in subclasses.") + return await forward(this, model_inputs); } /** @@ -1008,7 +1032,9 @@ export class PreTrainedModel extends Callable { * @param {boolean} [hasDecoder=false] Whether the model has a decoder. */ addPastKeyValues(decoderFeeds, pastKeyValues, hasDecoder = false) { - if (pastKeyValues === null) { + if (pastKeyValues) { + Object.assign(decoderFeeds, pastKeyValues) + } else { // TODO support batches (i.e., batch_size > 1) if (hasDecoder) { // @ts-ignore @@ -1036,9 +1062,6 @@ export class PreTrainedModel extends Callable { decoderFeeds[`past_key_values.${i}.value`] = new Tensor('float32', [], dims) } } - - } else { - Object.assign(decoderFeeds, pastKeyValues) } } } @@ -1046,7 +1069,23 @@ export class PreTrainedModel extends Callable { // Base model output class export class ModelOutput { } - +/** + * Base class for model's outputs, with potential hidden states and attentions. + */ +export class BaseModelOutput extends ModelOutput { + /** + * @param {Object} output The output of the model. 
+ * @param {Tensor} output.last_hidden_state Sequence of hidden-states at the output of the last layer of the model. + * @param {Tensor} [output.hidden_states] Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + * @param {Tensor} [output.attentions] Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + */ + constructor({ last_hidden_state, hidden_states = null, attentions = null }) { + super(); + this.last_hidden_state = last_hidden_state; + this.hidden_states = hidden_states; + this.attentions = attentions; + } +} ////////////////////////////////////////////////// // Bert models export class BertPreTrainedModel extends PreTrainedModel { } @@ -1064,8 +1103,7 @@ export class BertForMaskedLM extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for masked language modeling. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + return new MaskedLMOutput(await super._call(model_inputs)); } } @@ -1081,8 +1119,7 @@ export class BertForSequenceClassification extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1098,8 +1135,7 @@ export class BertForTokenClassification extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for token classification. 
*/ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new TokenClassifierOutput(logits) + return new TokenClassifierOutput(await super._call(model_inputs)); } } @@ -1115,8 +1151,7 @@ export class BertForQuestionAnswering extends BertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for question answering. */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1138,8 +1173,7 @@ export class DistilBertForSequenceClassification extends DistilBertPreTrainedMod * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1155,8 +1189,7 @@ export class DistilBertForTokenClassification extends DistilBertPreTrainedModel * @returns {Promise} An object containing the model's output logits for token classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new TokenClassifierOutput(logits) + return new TokenClassifierOutput(await super._call(model_inputs)); } } @@ -1173,8 +1206,7 @@ export class DistilBertForQuestionAnswering extends DistilBertPreTrainedModel { * @returns {Promise} An object containing the model's output logits for question answering. 
*/ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } @@ -1190,8 +1222,7 @@ export class DistilBertForMaskedLM extends DistilBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + return new MaskedLMOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1214,8 +1245,7 @@ export class MobileBertForMaskedLM extends MobileBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + return new MaskedLMOutput(await super._call(model_inputs)); } } @@ -1230,8 +1260,7 @@ export class MobileBertForSequenceClassification extends MobileBertPreTrainedMod * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1246,8 +1275,7 @@ export class MobileBertForQuestionAnswering extends MobileBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1265,8 +1293,7 @@ export class SqueezeBertForMaskedLM extends SqueezeBertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + return new 
MaskedLMOutput(await super._call(model_inputs)); } } export class SqueezeBertForSequenceClassification extends SqueezeBertPreTrainedModel { @@ -1277,8 +1304,7 @@ export class SqueezeBertForSequenceClassification extends SqueezeBertPreTrainedM * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } export class SqueezeBertForQuestionAnswering extends SqueezeBertPreTrainedModel { @@ -1289,8 +1315,7 @@ export class SqueezeBertForQuestionAnswering extends SqueezeBertPreTrainedModel * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1308,8 +1333,7 @@ export class AlbertForSequenceClassification extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } export class AlbertForQuestionAnswering extends AlbertPreTrainedModel { @@ -1320,8 +1344,7 @@ export class AlbertForQuestionAnswering extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } export class AlbertForMaskedLM extends AlbertPreTrainedModel { @@ -1332,8 +1355,7 @@ export class AlbertForMaskedLM extends AlbertPreTrainedModel { * @returns {Promise} returned object */ async 
_call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + return new MaskedLMOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1362,6 +1384,7 @@ export class T5Model extends T5PreTrainedModel { * @extends T5PreTrainedModel */ export class T5ForConditionalGeneration extends T5PreTrainedModel { + /** * Creates a new instance of the `T5ForConditionalGeneration` class. * @param {Object} config The model configuration. @@ -1383,34 +1406,6 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { this.encoder_dim_kv = this.config.d_kv; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `T5ForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Generates the start beams for a given set of inputs and output length. * @param {number[][]} inputs The input token IDs. @@ -1445,7 +1440,7 @@ export class T5ForConditionalGeneration extends T5PreTrainedModel { * @returns {Promise} The model output. 
*/ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -1474,6 +1469,7 @@ export class MT5Model extends MT5PreTrainedModel { * @extends MT5PreTrainedModel */ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { + /** * Creates a new instance of the `MT5ForConditionalGeneration` class. * @param {any} config The model configuration. @@ -1495,34 +1491,6 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { this.encoder_dim_kv = this.config.d_kv; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `MT5ForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Generates the start beams for the given input tokens and output sequence length. * @@ -1559,7 +1527,7 @@ export class MT5ForConditionalGeneration extends MT5PreTrainedModel { * @returns {Promise} A Promise that resolves to the model outputs. 
*/ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -1593,6 +1561,7 @@ export class BartModel extends BartPretrainedModel { * @extends BartPretrainedModel */ export class BartForConditionalGeneration extends BartPretrainedModel { + /** * Creates a new instance of the `BartForConditionalGeneration` class. * @param {Object} config The configuration object for the Bart model. @@ -1613,33 +1582,6 @@ export class BartForConditionalGeneration extends BartPretrainedModel { this.num_encoder_heads = this.config.encoder_attention_heads; this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `BartForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } /** * Returns the initial beam for generating output text. @@ -1676,7 +1618,7 @@ export class BartForConditionalGeneration extends BartPretrainedModel { * @returns {Promise} The model output. 
*/ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } @@ -1688,8 +1630,7 @@ export class BartForSequenceClassification extends BartPretrainedModel { * @returns {Promise} An object containing the model's output logits for sequence classification. */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1712,8 +1653,7 @@ export class RobertaForMaskedLM extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new MaskedLMOutput(logits) + return new MaskedLMOutput(await super._call(model_inputs)); } } @@ -1729,8 +1669,7 @@ export class RobertaForSequenceClassification extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } @@ -1746,8 +1685,7 @@ export class RobertaForQuestionAnswering extends RobertaPreTrainedModel { * @returns {Promise} returned object */ async _call(model_inputs) { - let outputs = await super._call(model_inputs); - return new QuestionAnsweringModelOutput(outputs.start_logits, outputs.end_logits); + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -1779,6 +1717,7 @@ export class WhisperModel extends WhisperPreTrainedModel { * @extends WhisperPreTrainedModel */ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { + /** * Creates a new instance of the `WhisperForConditionalGeneration` class. * @param {Object} config Configuration object for the model. 
@@ -1830,34 +1769,6 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { return super.generate(inputs, generation_config, logits_processor) } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `WhisperForConditionalGeneration` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Gets the start beams for generating outputs. * @param {Array} inputTokenIds Array of input token IDs. @@ -1895,9 +1806,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { * @returns {Promise} The model output. */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs, { - encoder_input_name: 'input_features', - }); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -1923,34 +1832,6 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { this.dim_kv = this.config.decoder.n_embd / this.num_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. 
- * - * @returns {Promise} A new instance of the `VisionEncoderDecoderModel` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await encoderDecoderLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Generate beam search outputs for the given input pixels and number of output tokens. * @@ -1991,8 +1872,7 @@ export class VisionEncoderDecoderModel extends PreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs, { - encoder_input_name: 'pixel_values', + return await seq2seqForward(this, model_inputs, { add_decoder_pkv: false }) } @@ -2010,14 +1890,28 @@ export class CLIPModel extends CLIPPreTrainedModel { ////////////////////////////////////////////////// // GPT2 models -export class GPT2PreTrainedModel extends PreTrainedModel { } -/** - * GPT2Model is not compatible with `.generate()`, as it doesn't have a language model head. - * @extends GPT2PreTrainedModel - */ +export class GPT2PreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `GPT2PreTrainedModel` class. + * @param {Object} config The configuration of the model. + * @param {any} session The ONNX session containing the model weights. 
+ */ + constructor(config, session) { + super(config, session); + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id + + this.num_heads = this.config.n_head + this.num_layers = this.config.n_layer + this.dim_kv = this.config.n_embd / this.num_heads; + } +} + export class GPT2Model extends GPT2PreTrainedModel { + /** - * + * GPT2Model is not compatible with `.generate()`, as it doesn't have a language model head. * @param {...any} args * @throws {Error} * @returns {Promise} @@ -2034,21 +1928,6 @@ export class GPT2Model extends GPT2PreTrainedModel { * @extends GPT2PreTrainedModel */ export class GPT2LMHeadModel extends GPT2PreTrainedModel { - /** - * Creates a new instance of the `GPT2LMHeadModel` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. - */ - constructor(config, session) { - super(config, session); - - // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id - - this.num_heads = this.config.n_head - this.num_layers = this.config.n_layer - this.dim_kv = this.config.n_embd / this.num_heads; - } /** * Initializes and returns the beam for text generation task @@ -2058,7 +1937,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @returns {any} A Beam object representing the initialized beam. */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { - return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) + return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } /** @@ -2067,7 +1946,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @returns {Promise} The updated beam after a single generation step. 
*/ async runBeam(beam) { - return await textgenRunBeam(this, beam); + return await decoderRunBeam(this, beam); } /** @@ -2076,7 +1955,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @param {number} newTokenId The new generated token id to be added to the beam. */ updateBeam(beam, newTokenId) { - return textgenUpdatebeam(beam, newTokenId); + return decoderUpdatebeam(beam, newTokenId); } /** @@ -2085,7 +1964,7 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await textgen_forward(this, model_inputs) + return await decoderForward(this, model_inputs); } } @@ -2093,7 +1972,23 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel { // TODO // } ////////////////////////////////////////////////// -export class GPTNeoPreTrainedModel extends PreTrainedModel { } +export class GPTNeoPreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `GPTNeoPreTrainedModel` class. + * @param {Object} config The configuration of the model. + * @param {any} session The ONNX session containing the model weights. + */ + constructor(config, session) { + super(config, session); + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id + + this.num_heads = this.config.num_heads; + this.num_layers = this.config.num_layers; + this.dim_kv = this.config.hidden_size / this.num_heads; + } +} export class GPTNeoModel extends GPTNeoPreTrainedModel { /** * @@ -2109,21 +2004,6 @@ export class GPTNeoModel extends GPTNeoPreTrainedModel { } export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { - /** - * Creates a new instance of the `GPTNeoForCausalLM` class. - * @param {Object} config The configuration of the model. - * @param {any} session The ONNX session containing the model weights. 
- */ - constructor(config, session) { - super(config, session); - - // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id - - this.num_heads = this.config.num_heads; - this.num_layers = this.config.num_layers; - this.dim_kv = this.config.hidden_size / this.num_heads; - } /** * Initializes and returns the beam for text generation task @@ -2133,7 +2013,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @returns {any} A Beam object representing the initialized beam. */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { - return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) + return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } /** @@ -2142,7 +2022,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @returns {Promise} The updated beam after a single generation step. */ async runBeam(beam) { - return await textgenRunBeam(this, beam); + return await decoderRunBeam(this, beam); } /** @@ -2151,7 +2031,7 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @param {number} newTokenId The new generated token id to be added to the beam. */ updateBeam(beam, newTokenId) { - return textgenUpdatebeam(beam, newTokenId); + return decoderUpdatebeam(beam, newTokenId); } /** @@ -2160,13 +2040,29 @@ export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await textgen_forward(this, model_inputs) + return await decoderForward(this, model_inputs); } } ////////////////////////////////////////////////// // CodeGen models -export class CodeGenPreTrainedModel extends PreTrainedModel { } +export class CodeGenPreTrainedModel extends PreTrainedModel { + /** + * Creates a new instance of the `CodeGenPreTrainedModel` class. + * @param {Object} config The model configuration object. 
+ * @param {Object} session The ONNX session object. + */ + constructor(config, session) { + super(config, session); + + // config doesn't contain pad_token_id, so we assume it is the eos_token_id + this.config.pad_token_id = this.config.eos_token_id + + this.num_heads = this.config.n_head + this.num_layers = this.config.n_layer + this.dim_kv = this.config.n_embd / this.num_heads; + } +} /** * CodeGenModel is a class representing a code generation model without a language model head. * @@ -2194,21 +2090,6 @@ export class CodeGenModel extends CodeGenPreTrainedModel { * @extends CodeGenPreTrainedModel */ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { - /** - * Creates a new instance of the `CodeGenForCausalLM` class. - * @param {Object} config The model configuration object. - * @param {Object} session The ONNX session object. - */ - constructor(config, session) { - super(config, session); - - // config doesn't contain pad_token_id, so we assume it is the eos_token_id - this.config.pad_token_id = this.config.eos_token_id - - this.num_heads = this.config.n_head - this.num_layers = this.config.n_layer - this.dim_kv = this.config.n_embd / this.num_heads; - } /** * Initializes and returns the beam for text generation task @@ -2218,7 +2099,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @returns {any} A Beam object representing the initialized beam. */ getStartBeams(inputTokenIds, numOutputTokens, inputs_attention_mask) { - return textgenStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) + return decoderStartBeams(this, inputTokenIds, numOutputTokens, inputs_attention_mask) } /** @@ -2227,7 +2108,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @returns {Promise} The updated beam after a single generation step. 
*/ async runBeam(beam) { - return await textgenRunBeam(this, beam); + return await decoderRunBeam(this, beam); } /** @@ -2236,7 +2117,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @param {number} newTokenId The new generated token id to be added to the beam. */ updateBeam(beam, newTokenId) { - return textgenUpdatebeam(beam, newTokenId); + return decoderUpdatebeam(beam, newTokenId); } /** @@ -2245,7 +2126,7 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel { * @returns {Promise} The output tensor of the model. */ async forward(model_inputs) { - return await textgen_forward(this, model_inputs) + return await decoderForward(this, model_inputs); } } @@ -2258,8 +2139,7 @@ export class ViTForImageClassification extends ViTPreTrainedModel { * @param {any} model_inputs */ async _call(model_inputs) { - let logits = (await super._call(model_inputs)).logits; - return new SequenceClassifierOutput(logits) + return new SequenceClassifierOutput(await super._call(model_inputs)); } } ////////////////////////////////////////////////// @@ -2271,8 +2151,7 @@ export class DetrForObjectDetection extends DetrPreTrainedModel { * @param {any} model_inputs */ async _call(model_inputs) { - let output = (await super._call(model_inputs)); - return new DetrObjectDetectionOutput(output.logits, output.pred_boxes) + return new DetrObjectDetectionOutput(await super._call(model_inputs)); } } @@ -2283,17 +2162,18 @@ export class DetrForSegmentation extends DetrPreTrainedModel { * @returns {Promise} Object containing segmentation outputs */ async _call(model_inputs) { - let output = (await super._call(model_inputs)); - return new DetrSegmentationOutput(output.logits, output.pred_boxes, output.pred_masks); + return new DetrSegmentationOutput(await super._call(model_inputs)); } } export class DetrObjectDetectionOutput extends ModelOutput { /** - * @param {Tensor} logits - * @param {Tensor} pred_boxes + * @param {Object} output The output of the model. 
+ * @param {Tensor} output.logits Classification logits (including no-object) for all queries. + * @param {Tensor} output.pred_boxes Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). + * These values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding possible padding). */ - constructor(logits, pred_boxes) { + constructor({ logits, pred_boxes }) { super(); this.logits = logits; this.pred_boxes = pred_boxes; @@ -2301,13 +2181,13 @@ export class DetrObjectDetectionOutput extends ModelOutput { } export class DetrSegmentationOutput extends ModelOutput { - /** - * @param {Tensor} logits The output logits of the model. - * @param {Tensor} pred_boxes Predicted boxes. - * @param {Tensor} pred_masks Predicted masks. + * @param {Object} output The output of the model. + * @param {Tensor} output.logits The output logits of the model. + * @param {Tensor} output.pred_boxes Predicted boxes. + * @param {Tensor} output.pred_masks Predicted masks. */ - constructor(logits, pred_boxes, pred_masks) { + constructor({ logits, pred_boxes, pred_masks }) { super(); this.logits = logits; this.pred_boxes = pred_boxes; @@ -2327,24 +2207,21 @@ export class SamModel extends SamPreTrainedModel { * @todo Add support for `input_labels`, `input_boxes`, `input_masks`, and `image_embeddings`. */ async _call(model_inputs) { - // TODO split into encoder and decoder - let output = (await super._call(model_inputs)); - return new SamImageSegmentationOutput(output.iou_scores, output.pred_masks); + return new SamImageSegmentationOutput(await super._call(model_inputs)); } } /** * Base class for Segment-Anything model's output. - * - * @extends ModelOutput */ export class SamImageSegmentationOutput extends ModelOutput { /** - * @param {Tensor} iou_scores The output logits of the model. - * @param {Tensor} pred_masks Predicted boxes. + * @param {Object} output The output of the model. 
+ * @param {Tensor} output.iou_scores The output logits of the model. + * @param {Tensor} output.pred_masks Predicted masks. */ - constructor(iou_scores, pred_masks) { + constructor({ iou_scores, pred_masks }) { super(); this.iou_scores = iou_scores; this.pred_masks = pred_masks; @@ -2372,6 +2249,7 @@ export class MarianModel extends MarianPreTrainedModel { } export class MarianMTModel extends MarianPreTrainedModel { + /** * Creates a new instance of the `MarianMTModel` class. * @param {Object} config The model configuration object. @@ -2393,34 +2271,6 @@ export class MarianMTModel extends MarianPreTrainedModel { this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `MarianMTModel` class. - */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } - /** * Initializes and returns the beam for text generation task * @param {any[]} inputs The input token ids. 
@@ -2454,7 +2304,7 @@ export class MarianMTModel extends MarianPreTrainedModel { * @returns {Promise} */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -2478,6 +2328,7 @@ export class M2M100Model extends M2M100PreTrainedModel { } export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { + /** * Creates a new instance of the `M2M100ForConditionalGeneration` class. * @param {Object} config The model configuration object. @@ -2499,33 +2350,6 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { this.encoder_dim_kv = this.config.d_model / this.num_encoder_heads; } - /** - * Loads a pre-trained model from the given `pretrained_model_name_or_path`. - * - * @param {string} pretrained_model_name_or_path The path to the pre-trained model. - * @param {PretrainedOptions} options Additional options for loading the model. For more information, @see {@link PreTrainedModel.from_pretrained}. - * - * @returns {Promise} A new instance of the `M2M100ForConditionalGeneration` class. 
- */ - static async from_pretrained(pretrained_model_name_or_path, { - quantized = true, - progress_callback = null, - config = null, - cache_dir = null, - local_files_only = false, - revision = 'main', - } = {}) { - let info = await seq2seqLoadModel(pretrained_model_name_or_path, { - quantized, - progress_callback, - config, - cache_dir, - local_files_only, - revision, - }); - // @ts-ignore - return new this(...info); - } /** * Initializes and returns the beam for text generation task @@ -2560,7 +2384,7 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { * @returns {Promise} */ async forward(model_inputs) { - return await seq2seq_forward(this, model_inputs); + return await seq2seqForward(this, model_inputs); } } ////////////////////////////////////////////////// @@ -2577,8 +2401,9 @@ export class M2M100ForConditionalGeneration extends M2M100PreTrainedModel { export class PretrainedMixin { /** * Mapping from model type to model class. + * @type {Map[]} */ - static MODEL_CLASS_MAPPING = Object.create(null); + static MODEL_CLASS_MAPPINGS = null; /** * Whether to attempt to instantiate the base class (`PretrainedModel`) if @@ -2586,26 +2411,8 @@ export class PretrainedMixin { */ static BASE_IF_FAIL = false; - /** - * The function to use to load the pretrained model. - */ - static LOAD_FUNCTION = null; - /** - * Instantiate one of the model classes of the library from a pretrained model. - * - * The model class to instantiate is selected based on the `model_type` property of the config object - * (either passed as an argument or loaded from `pretrained_model_name_or_path` if possible) - * - * @param {string} pretrained_model_name_or_path The name or path of the pretrained model. Can be either: - * - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. 
- * Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a - * user or organization name, like `dbmdz/bert-base-german-cased`. - * - A path to a *directory* containing model weights, e.g., `./my_model_directory/`. - * @param {PretrainedOptions} options Additional options for loading the model. - * - * @returns {Promise} A new instance of the PreTrainedModel class. - */ + /** @type {PreTrainedModel.from_pretrained} */ static async from_pretrained(pretrained_model_name_or_path, { quantized = true, progress_callback = null, @@ -2614,29 +2421,158 @@ export class PretrainedMixin { local_files_only = false, revision = 'main', } = {}) { - if (this.LOAD_FUNCTION === null) { - throw new Error("`LOAD_FUNCTION` not implemented for this model"); - } - let info = await this.LOAD_FUNCTION(pretrained_model_name_or_path, { + let options = { quantized, progress_callback, config, cache_dir, local_files_only, revision, - }); + } + config = await AutoConfig.from_pretrained(pretrained_model_name_or_path, options); - let cls = this.MODEL_CLASS_MAPPING[info[0].model_type]; - if (!cls) { - if (this.BASE_IF_FAIL) { - console.warn(`Unknown model class "${info[0].model_type}", attempting to construct from base class.`); - cls = PreTrainedModel; - } else { - throw Error(`Unsupported model type: ${info[0].model_type}`) + if (!this.MODEL_CLASS_MAPPINGS) { + throw new Error("`MODEL_CLASS_MAPPINGS` not implemented for this type of `AutoClass`: " + this.name); + } + + let modelClass; + for (let MODEL_CLASS_MAPPING of this.MODEL_CLASS_MAPPINGS) { + modelClass = MODEL_CLASS_MAPPING.get(config.model_type); + if (!modelClass) { + continue; // Item not found in this mapping } + + return await modelClass.from_pretrained(pretrained_model_name_or_path, options); + } + + if (this.BASE_IF_FAIL) { + console.warn(`Unknown model class "${config.model_type}", attempting to construct from base class.`); + return await 
PreTrainedModel.from_pretrained(pretrained_model_name_or_path, options); + } else { + throw Error(`Unsupported model type: ${config.model_type}`) } - return new cls(...info); + } +} + +const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ + ['bert', BertModel], + ['albert', AlbertModel], + ['distilbert', DistilBertModel], + ['roberta', RobertaModel], + ['clip', CLIPModel], + ['mobilebert', MobileBertModel], + ['squeezebert', SqueezeBertModel], + + ['sam', SamModel], // TODO change to encoder-decoder when model is split correctly +]); + +const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([ + ['t5', T5Model], + ['mt5', MT5Model], + ['bart', BartModel], + ['marian', MarianModel], + ['whisper', WhisperModel], + ['m2m_100', M2M100Model], +]); + + +const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([ + ['gpt2', GPT2Model], + ['gpt_neo', GPTNeoModel], + ['codegen', CodeGenModel], +]); + +const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ + ['bert', BertForSequenceClassification], + ['albert', AlbertForSequenceClassification], + ['distilbert', DistilBertForSequenceClassification], + ['roberta', RobertaForSequenceClassification], + ['bart', BartForSequenceClassification], + ['mobilebert', MobileBertForSequenceClassification], + ['squeezebert', SqueezeBertForSequenceClassification], +]); + +const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ + ['bert', BertForTokenClassification], + ['distilbert', DistilBertForTokenClassification], +]); + +const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ + ['t5', T5ForConditionalGeneration], + ['mt5', MT5ForConditionalGeneration], + ['bart', BartForConditionalGeneration], + ['whisper', WhisperForConditionalGeneration], + ['marian', MarianMTModel], + ['m2m_100', M2M100ForConditionalGeneration], +]); + +const MODEL_WITH_LM_HEAD_MAPPING_NAMES = new Map([ + ['gpt2', GPT2LMHeadModel], + ['gpt_neo', GPTNeoForCausalLM], + ['codegen', CodeGenForCausalLM], +]); + +const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ + ['bert', 
BertForMaskedLM], + ['albert', AlbertForMaskedLM], + ['distilbert', DistilBertForMaskedLM], + ['roberta', RobertaForMaskedLM], + ['mobilebert', MobileBertForMaskedLM], + ['squeezebert', SqueezeBertForMaskedLM], +]); + +const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ + ['bert', BertForQuestionAnswering], + ['albert', AlbertForQuestionAnswering], + ['distilbert', DistilBertForQuestionAnswering], + ['roberta', RobertaForQuestionAnswering], + ['mobilebert', MobileBertForQuestionAnswering], + ['squeezebert', SqueezeBertForQuestionAnswering], +]); + +const MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES = new Map([ + ['vision-encoder-decoder', VisionEncoderDecoderModel], +]); + +const MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES = new Map([ + ['vit', ViTForImageClassification], +]); + +const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([ + ['detr', DetrForObjectDetection], +]); + +const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([ + ['detr', DetrForSegmentation], +]); + +const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([ + ['sam', SamModel], +]); + +const MODEL_CLASS_TYPE_MAPPING = [ + [MODEL_MAPPING_NAMES_ENCODER_ONLY, EncoderOnlyModelType], + [MODEL_MAPPING_NAMES_ENCODER_DECODER, EncoderDecoderModelType], + [MODEL_MAPPING_NAMES_DECODER_ONLY, DecoderOnlyModelType], + [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES, Seq2SeqModelType], + [MODEL_WITH_LM_HEAD_MAPPING_NAMES, DecoderOnlyModelType], + [MODEL_FOR_MASKED_LM_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, EncoderDecoderModelType], + [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, EncoderOnlyModelType], + [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, EncoderOnlyModelType], + 
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, EncoderOnlyModelType], +]; + +for (let [mappings, type] of MODEL_CLASS_TYPE_MAPPING) { + // @ts-ignore + for (let [name, model] of mappings.entries()) { + MODEL_TYPE_MAPPING.set(model.name, type); + MODEL_CLASS_MAPPING.set(model.name, name); } } @@ -2648,27 +2584,8 @@ export class PretrainedMixin { * let model = await AutoModel.from_pretrained('bert-base-uncased'); */ export class AutoModel extends PretrainedMixin { - static LOAD_FUNCTION = loadAutoModel; + static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY]; static BASE_IF_FAIL = true; - static MODEL_CLASS_MAPPING = { - 'bert': BertModel, - 'albert': AlbertModel, - 'distilbert': DistilBertModel, - 't5': T5Model, - 'mt5': MT5Model, - 'gpt2': GPT2Model, - 'gpt_neo': GPTNeoModel, - 'codegen': CodeGenModel, - 'bart': BartModel, - 'roberta': RobertaModel, - 'whisper': WhisperModel, - 'clip': CLIPModel, - 'mobilebert': MobileBertModel, - 'squeezebert': SqueezeBertModel, - 'marian': MarianModel, - 'm2m_100': M2M100Model, - 'sam': SamModel, - } } /** @@ -2679,16 +2596,7 @@ export class AutoModel extends PretrainedMixin { * let model = await AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased-finetuned-sst-2-english'); */ export class AutoModelForSequenceClassification extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; - static MODEL_CLASS_MAPPING = { - 'bert': BertForSequenceClassification, - 'albert': AlbertForSequenceClassification, - 'distilbert': DistilBertForSequenceClassification, - 'roberta': RobertaForSequenceClassification, - 'bart': BartForSequenceClassification, - 'mobilebert': MobileBertForSequenceClassification, - 'squeezebert': SqueezeBertForSequenceClassification, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES]; } /** @@ -2699,14 +2607,9 @@ export class AutoModelForSequenceClassification extends PretrainedMixin 
{ * let model = await AutoModelForTokenClassification.from_pretrained('Davlan/distilbert-base-multilingual-cased-ner-hrl'); */ export class AutoModelForTokenClassification extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; - static MODEL_CLASS_MAPPING = { - 'bert': BertForTokenClassification, - 'distilbert': DistilBertForTokenClassification, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES]; } - /** * Helper class which is used to instantiate pretrained sequence-to-sequence models with the `from_pretrained` function. * The chosen model class is determined by the type specified in the model config. @@ -2715,15 +2618,7 @@ export class AutoModelForTokenClassification extends PretrainedMixin { * let model = await AutoModelForSeq2SeqLM.from_pretrained('t5-small'); */ export class AutoModelForSeq2SeqLM extends PretrainedMixin { - static LOAD_FUNCTION = seq2seqLoadModel; - static MODEL_CLASS_MAPPING = { - 't5': T5ForConditionalGeneration, - 'mt5': MT5ForConditionalGeneration, - 'bart': BartForConditionalGeneration, - 'whisper': WhisperForConditionalGeneration, - 'marian': MarianMTModel, - 'm2m_100': M2M100ForConditionalGeneration, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES]; } /** @@ -2734,12 +2629,7 @@ export class AutoModelForSeq2SeqLM extends PretrainedMixin { * let model = await AutoModelForCausalLM.from_pretrained('gpt2'); */ export class AutoModelForCausalLM extends PretrainedMixin { - static LOAD_FUNCTION = decoderLoadModel; - static MODEL_CLASS_MAPPING = { - 'gpt2': GPT2LMHeadModel, - 'gpt_neo': GPTNeoForCausalLM, - 'codegen': CodeGenForCausalLM, - } + static MODEL_CLASS_MAPPINGS = [MODEL_WITH_LM_HEAD_MAPPING_NAMES]; } /** @@ -2750,15 +2640,7 @@ export class AutoModelForCausalLM extends PretrainedMixin { * let model = await AutoModelForMaskedLM.from_pretrained('bert-base-uncased'); */ export class AutoModelForMaskedLM extends PretrainedMixin { - static LOAD_FUNCTION = loadAutoModel; - static 
MODEL_CLASS_MAPPING = { - 'bert': BertForMaskedLM, - 'albert': AlbertForMaskedLM, - 'distilbert': DistilBertForMaskedLM, - 'roberta': RobertaForMaskedLM, - 'mobilebert': MobileBertForMaskedLM, - 'squeezebert': SqueezeBertForMaskedLM, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASKED_LM_MAPPING_NAMES]; } /** @@ -2769,15 +2651,7 @@ export class AutoModelForMaskedLM extends PretrainedMixin { * let model = await AutoModelForQuestionAnswering.from_pretrained('distilbert-base-cased-distilled-squad'); */ export class AutoModelForQuestionAnswering extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; - static MODEL_CLASS_MAPPING = { - 'bert': BertForQuestionAnswering, - 'albert': AlbertForQuestionAnswering, - 'distilbert': DistilBertForQuestionAnswering, - 'roberta': RobertaForQuestionAnswering, - 'mobilebert': MobileBertForQuestionAnswering, - 'squeezebert': SqueezeBertForQuestionAnswering, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES]; } /** @@ -2788,10 +2662,7 @@ export class AutoModelForQuestionAnswering extends PretrainedMixin { * let model = await AutoModelForVision2Seq.from_pretrained('nlpconnect/vit-gpt2-image-captioning'); */ export class AutoModelForVision2Seq extends PretrainedMixin { - static LOAD_FUNCTION = encoderDecoderLoadModel; - static MODEL_CLASS_MAPPING = { - 'vision-encoder-decoder': VisionEncoderDecoderModel - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES]; } /** @@ -2802,10 +2673,7 @@ export class AutoModelForVision2Seq extends PretrainedMixin { * let model = await AutoModelForImageClassification.from_pretrained('google/vit-base-patch16-224'); */ export class AutoModelForImageClassification extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; - static MODEL_CLASS_MAPPING = { - 'vit': ViTForImageClassification, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES]; } /** @@ -2816,10 +2684,7 @@ export class AutoModelForImageClassification 
extends PretrainedMixin { * let model = await AutoModelForImageSegmentation.from_pretrained('facebook/detr-resnet-50-panoptic'); */ export class AutoModelForImageSegmentation extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; - static MODEL_CLASS_MAPPING = { - 'detr': DetrForSegmentation, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES]; } /** @@ -2830,10 +2695,7 @@ export class AutoModelForImageSegmentation extends PretrainedMixin { * let model = await AutoModelForObjectDetection.from_pretrained('facebook/detr-resnet-50'); */ export class AutoModelForObjectDetection extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; - static MODEL_CLASS_MAPPING = { - 'detr': DetrForObjectDetection, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES]; } /** @@ -2844,21 +2706,19 @@ export class AutoModelForObjectDetection extends PretrainedMixin { * let model = await AutoModelForMaskGeneration.from_pretrained('Xenova/sam-vit-base'); */ export class AutoModelForMaskGeneration extends PretrainedMixin { - static LOAD_FUNCTION = loadModel; - static MODEL_CLASS_MAPPING = { - 'sam': SamModel, - } + static MODEL_CLASS_MAPPINGS = [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES]; } ////////////////////////////////////////////////// ////////////////////////////////////////////////// export class Seq2SeqLMOutput extends ModelOutput { /** - * @param {Tensor} logits The output logits of the model. - * @param {Tensor} past_key_values An tensor of key/value pairs that represent the previous state of the model. - * @param {Tensor} encoder_outputs The output of the encoder in a sequence-to-sequence model. + * @param {Object} output The output of the model. + * @param {Tensor} output.logits The output logits of the model. + * @param {Tensor} output.past_key_values An tensor of key/value pairs that represent the previous state of the model. 
+ * @param {Tensor} output.encoder_outputs The output of the encoder in a sequence-to-sequence model. */ - constructor(logits, past_key_values, encoder_outputs) { + constructor({ logits, past_key_values, encoder_outputs }) { super(); this.logits = logits; this.past_key_values = past_key_values; @@ -2866,45 +2726,78 @@ export class Seq2SeqLMOutput extends ModelOutput { } } +/** + * Base class for outputs of sentence classification models. + */ export class SequenceClassifierOutput extends ModelOutput { /** - * @param {Tensor} logits + * @param {Object} output The output of the model. + * @param {Tensor} output.logits classification (or regression if config.num_labels==1) scores (before SoftMax). */ - constructor(logits) { + constructor({ logits }) { super(); this.logits = logits; } } +/** + * Base class for outputs of token classification models. + */ export class TokenClassifierOutput extends ModelOutput { /** - * @param {Tensor} logits + * @param {Object} output The output of the model. + * @param {Tensor} output.logits Classification scores (before SoftMax). */ - constructor(logits) { + constructor({ logits }) { super(); this.logits = logits; } } - +/** + * Base class for masked language models outputs. + */ export class MaskedLMOutput extends ModelOutput { /** - * @param {Tensor} logits + * @param {Object} output The output of the model. + * @param {Tensor} output.logits Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). */ - constructor(logits) { + constructor({ logits }) { super(); this.logits = logits; } } +/** + * Base class for outputs of question answering models. + */ export class QuestionAnsweringModelOutput extends ModelOutput { /** - * @param {Tensor} start_logits The logits for start positions of the answer. - * @param {Tensor} end_logits The logits for end positions of the answer. + * @param {Object} output The output of the model. 
+ * @param {Tensor} output.start_logits Span-start scores (before SoftMax). + * @param {Tensor} output.end_logits Span-end scores (before SoftMax). */ - constructor(start_logits, end_logits) { + constructor({ start_logits, end_logits }) { super(); this.start_logits = start_logits; this.end_logits = end_logits; } } + + +/** + * Base class for causal language model (or autoregressive) outputs. + */ +export class CausalLMOutputWithPast extends ModelOutput { + /** + * @param {Object} output The output of the model. + * @param {Tensor} output.logits Prediction scores of the language modeling head (scores for each vocabulary token before softmax). + * @param {Tensor} output.past_key_values Contains pre-computed hidden-states (key and values in the self-attention blocks) + * that can be used (see `past_key_values` input) to speed up sequential decoding. + */ + constructor({ logits, past_key_values }) { + super(); + this.logits = logits; + this.past_key_values = past_key_values; + } +} \ No newline at end of file diff --git a/src/utils/tensor.js b/src/utils/tensor.js index af1f9fa1b..d452287a4 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -440,6 +440,61 @@ export class Tensor extends ONNXTensor { this.dims = calc_unsqueeze_dims(this.dims, dim); return this; } + + /** + * In-place version of @see {@link Tensor.flatten} + */ + flatten_(start_dim = 0, end_dim = -1) { + // TODO validate inputs + end_dim = (end_dim + this.dims.length) % this.dims.length; + + let dimsToKeepBefore = this.dims.slice(0, start_dim); + let dimsToFlatten = this.dims.slice(start_dim, end_dim + 1); + let dimsToKeepAfter = this.dims.slice(end_dim + 1); + + this.dims = [...dimsToKeepBefore, dimsToFlatten.reduce((a, b) => a * b, 1), ...dimsToKeepAfter] + return this; + } + + /** + * Flattens input by reshaping it into a one-dimensional tensor. + * If `start_dim` or `end_dim` are passed, only dimensions starting with `start_dim` + * and ending with `end_dim` are flattened. 
The order of elements in input is unchanged. + * @param {number} start_dim the first dim to flatten + * @param {number} end_dim the last dim to flatten + * @returns The flattened tensor. + */ + flatten(start_dim = 0, end_dim = -1) { + return this.clone().flatten_(start_dim, end_dim); + } + + /** + * Returns a new tensor with the same data as the `self` tensor but of a different `shape`. + * @param {...number} dims the desired size + * @returns {Tensor} The tensor with the same data but different shape + */ + view(...dims) { + // TODO: validate dims + let inferredIndex = -1; + for (let i = 0; i < dims.length; ++i) { + if (dims[i] === -1) { + if (inferredIndex !== -1) { + throw new Error("Only one dimension can be inferred"); + } + inferredIndex = i; + } + } + + if (inferredIndex !== -1) { + // Some dimension must be inferred + const productOther = dims.reduce((product, curr, index) => { + return index !== inferredIndex ? product * curr : product + }, 1); + + dims[inferredIndex] = this.data.length / productOther; + } + return new Tensor(this.type, this.data, dims); // NOTE: uses same underlying storage + } } /** diff --git a/tests/models.test.js b/tests/models.test.js new file mode 100644 index 000000000..b012581de --- /dev/null +++ b/tests/models.test.js @@ -0,0 +1,90 @@ +/* + * Test that models loaded outside of the `pipeline` function work correctly (e.g., `AutoModel.from_pretrained(...)`); + */ + +import { + AutoTokenizer, + AutoModel, + + BertModel, + GPT2Model, + T5Model, + + BertTokenizer, + GPT2Tokenizer, + T5Tokenizer, + +} from '../src/transformers.js'; + +import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js'; + +import { compare } from './test_utils.js'; + +// Initialise the testing environment +init(); + + +describe('Loading models', () => { + + // List all models which will be tested + const models_to_test = [ + // [name, modelClass, tokenizerClass] + ['bert-base-uncased', BertModel, BertTokenizer], // Encoder-only + ['gpt2', GPT2Model, 
GPT2Tokenizer], // Decoder-only + ['t5-small', T5Model, T5Tokenizer], // Encoder-decoder + ]; + + let texts = [ + 'Once upon a time', + 'I like to eat apples', + ]; + + for (let [name, modelClass, tokenizerClass] of models_to_test) { + + // Test that both the auto model and the specific model work + let tokenizers = [AutoTokenizer, tokenizerClass]; + let models = [AutoModel, modelClass]; + + for (let i = 0; i < tokenizers.length; ++i) { + const tokenizerClassToTest = tokenizers[i]; + const modelClassToTest = models[i]; + + it(`${name} (${modelClassToTest.name})`, async () => { + const model_id = m(name); + + // Load model and tokenizer + let tokenizer = await tokenizerClassToTest.from_pretrained(model_id); + let model = await modelClassToTest.from_pretrained(model_id); + + let tests = [ + texts[0], // single + texts, // batched + ] + for (let test of tests) { + let encodings = await tokenizer(test, { truncation: true, padding: true }); + let output = await model(encodings); + + if (output.logits) { + // Ensure correct shapes + let expected_shape = [...encodings.input_ids.dims, model.config.vocab_size]; + let actual_shape = output.logits.dims; + compare(expected_shape, actual_shape); + } else if (output.last_hidden_state) { + let expected_shape = [...encodings.input_ids.dims, model.config.d_model]; + let actual_shape = output.last_hidden_state.dims; + compare(expected_shape, actual_shape); + } else { + console.warn('Unexpected output', output); + throw new Error('Unexpected output'); + } + + } + + await model.dispose(); + + }, MAX_TEST_EXECUTION_TIME); + + } + } + +}); \ No newline at end of file diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js index 26d5100f6..5d988ef83 100644 --- a/tests/pipelines.test.js +++ b/tests/pipelines.test.js @@ -1,50 +1,11 @@ import { pipeline, cos_sim } from '../src/transformers.js'; import { init, m, MAX_TEST_EXECUTION_TIME } from './init.js'; - +import { compare } from './test_utils.js'; // Initialise the testing 
environment init(); - -function compare(val1, val2, tol = 0.1) { - if ( - (val1 !== null && val2 !== null) && - (typeof val1 === 'object' && typeof val2 === 'object') - ) { - // Both are non-null objects - - if (Array.isArray(val1) && Array.isArray(val2)) { - expect(val1).toHaveLength(val2.length); - - for (let i = 0; i < val1.length; ++i) { - compare(val1[i], val2[i], tol); - } - - } else { - expect(Object.keys(val1)).toHaveLength(Object.keys(val2).length); - - for (let key in val1) { - compare(val1[key], val2[key]); - } - } - - } else { - // At least one of them is not an object - // First check that both have the same type - expect(typeof val1).toEqual(typeof val2); - - if (typeof val1 === 'number' && (!Number.isInteger(val1) || !Number.isInteger(val2))) { - // If both are numbers and at least one of them is not an integer - expect(val1).toBeCloseTo(val2, tol); - } else { - // Perform equality test - expect(val1).toEqual(val2); - } - } -} - - // NOTE: // Due to a memory leak in Jest, we cannot have multiple tests for a single model. // This is due to how model construction and destruction occurs, in `beforeAll` and `afterAll`, respectively. 
// tests/test_utils.js

/**
 * Deep equality test (for arrays and objects) with tolerance for floating point numbers.
 *
 * Arrays are compared element-wise and plain objects key-by-key, recursing with the
 * same `tol` at every level. Leaves are compared with Jest's `expect`, so a mismatch
 * surfaces as a failed expectation rather than a return value.
 *
 * @param {any} val1 The first item
 * @param {any} val2 The second item
 * @param {number} tol Tolerance for floating point numbers. NOTE: this value is forwarded
 * as the second argument of `expect(...).toBeCloseTo`, which Jest interprets as a number of
 * decimal digits (`numDigits`): the accepted difference is `10 ** -tol / 2`, not `tol` itself.
 * @returns {void}
 */
export function compare(val1, val2, tol = 0.1) {
    if (
        (val1 !== null && val2 !== null) &&
        (typeof val1 === 'object' && typeof val2 === 'object')
    ) {
        // Both are non-null objects

        if (Array.isArray(val1) && Array.isArray(val2)) {
            expect(val1).toHaveLength(val2.length);

            for (let i = 0; i < val1.length; ++i) {
                compare(val1[i], val2[i], tol);
            }

        } else {
            const keys = Object.keys(val1);
            expect(keys).toHaveLength(Object.keys(val2).length);

            // Iterate the same own keys that were counted above, and keep propagating
            // `tol` so that nested numbers are not silently checked with the default.
            for (const key of keys) {
                compare(val1[key], val2[key], tol);
            }
        }

    } else {
        // At least one of them is not an object
        // First check that both have the same type
        expect(typeof val1).toEqual(typeof val2);

        if (typeof val1 === 'number' && (!Number.isInteger(val1) || !Number.isInteger(val2))) {
            // If both are numbers and at least one of them is not an integer
            expect(val1).toBeCloseTo(val2, tol);
        } else {
            // Perform equality test
            expect(val1).toEqual(val2);
        }
    }
}