From 9a51c998f083888e17c326217808b4ea4f1f7b6a Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Sat, 1 Jul 2023 03:25:01 +0200 Subject: [PATCH] Add xlm-roberta models (Fixes #177) --- src/models.js | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/src/models.js b/src/models.js index 3ece97429..7d7e1bae8 100644 --- a/src/models.js +++ b/src/models.js @@ -1708,6 +1708,76 @@ export class RobertaForQuestionAnswering extends RobertaPreTrainedModel { } ////////////////////////////////////////////////// +////////////////////////////////////////////////// +// XLMRoberta models +export class XLMRobertaPreTrainedModel extends PreTrainedModel { } +export class XLMRobertaModel extends XLMRobertaPreTrainedModel { } + +/** + * XLMRobertaForMaskedLM class for performing masked language modeling on XLMRoberta models. + * @extends XLMRobertaPreTrainedModel + */ +export class XLMRobertaForMaskedLM extends XLMRobertaPreTrainedModel { + /** + * Calls the model on new inputs by delegating to the parent class's `_call`. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise<MaskedLMOutput>} The parent's result wrapped in a `MaskedLMOutput`. + */ + async _call(model_inputs) { + return new MaskedLMOutput(await super._call(model_inputs)); + } +} + +/** + * XLMRobertaForSequenceClassification class for performing sequence classification on XLMRoberta models. + * @extends XLMRobertaPreTrainedModel + */ +export class XLMRobertaForSequenceClassification extends XLMRobertaPreTrainedModel { + /** + * Calls the model on new inputs by delegating to the parent class's `_call`. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise<SequenceClassifierOutput>} The parent's result wrapped in a `SequenceClassifierOutput`. + */ + async _call(model_inputs) { + return new SequenceClassifierOutput(await super._call(model_inputs)); + } +} + +/** + * XLMRobertaForTokenClassification class for performing token classification on XLMRoberta models. 
+ * @extends XLMRobertaPreTrainedModel + */ +export class XLMRobertaForTokenClassification extends XLMRobertaPreTrainedModel { + /** + * Calls the model on new inputs by delegating to the parent class's `_call`. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise<TokenClassifierOutput>} The parent's result wrapped in a `TokenClassifierOutput`. + */ + async _call(model_inputs) { + return new TokenClassifierOutput(await super._call(model_inputs)); + } +} + +/** + * XLMRobertaForQuestionAnswering class for performing question answering on XLMRoberta models. + * @extends XLMRobertaPreTrainedModel + */ +export class XLMRobertaForQuestionAnswering extends XLMRobertaPreTrainedModel { + /** + * Calls the model on new inputs by delegating to the parent class's `_call`. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise<QuestionAnsweringModelOutput>} The parent's result wrapped in a `QuestionAnsweringModelOutput`. + */ + async _call(model_inputs) { + return new QuestionAnsweringModelOutput(await super._call(model_inputs)); + } +} +////////////////////////////////////////////////// + ////////////////////////////////////////////////// // T5 models export class WhisperPreTrainedModel extends PreTrainedModel { }; @@ -2494,6 +2564,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([ ['albert', AlbertModel], ['distilbert', DistilBertModel], ['roberta', RobertaModel], + ['xlm-roberta', XLMRobertaModel], ['clip', CLIPModel], ['mobilebert', MobileBertModel], ['squeezebert', SqueezeBertModel], @@ -2522,6 +2593,7 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ ['albert', AlbertForSequenceClassification], ['distilbert', DistilBertForSequenceClassification], ['roberta', RobertaForSequenceClassification], + ['xlm-roberta', XLMRobertaForSequenceClassification], ['bart', BartForSequenceClassification], ['mobilebert', MobileBertForSequenceClassification], ['squeezebert', SqueezeBertForSequenceClassification], @@ -2531,6 +2603,7 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ ['bert', BertForTokenClassification], ['distilbert', 
DistilBertForTokenClassification], ['roberta', RobertaForTokenClassification], + ['xlm-roberta', XLMRobertaForTokenClassification], ]); const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([ @@ -2553,6 +2626,7 @@ const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([ ['albert', AlbertForMaskedLM], ['distilbert', DistilBertForMaskedLM], ['roberta', RobertaForMaskedLM], + ['xlm-roberta', XLMRobertaForMaskedLM], ['mobilebert', MobileBertForMaskedLM], ['squeezebert', SqueezeBertForMaskedLM], ]); @@ -2562,6 +2636,7 @@ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([ ['albert', AlbertForQuestionAnswering], ['distilbert', DistilBertForQuestionAnswering], ['roberta', RobertaForQuestionAnswering], + ['xlm-roberta', XLMRobertaForQuestionAnswering], ['mobilebert', MobileBertForQuestionAnswering], ['squeezebert', SqueezeBertForQuestionAnswering], ]);