Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,76 @@ export class RobertaForQuestionAnswering extends RobertaPreTrainedModel {
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// XLMRoberta models
export class XLMRobertaPreTrainedModel extends PreTrainedModel { }
export class XLMRobertaModel extends XLMRobertaPreTrainedModel { }

/**
* XLMRobertaForMaskedLM class for performing masked language modeling on XLMRoberta models.
* @extends XLMRobertaPreTrainedModel
*/
export class XLMRobertaForMaskedLM extends XLMRobertaPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<MaskedLMOutput>} returned object
*/
async _call(model_inputs) {
return new MaskedLMOutput(await super._call(model_inputs));
}
}

/**
* XLMRobertaForSequenceClassification class for performing sequence classification on XLMRoberta models.
* @extends XLMRobertaPreTrainedModel
*/
export class XLMRobertaForSequenceClassification extends XLMRobertaPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<SequenceClassifierOutput>} returned object
*/
async _call(model_inputs) {
return new SequenceClassifierOutput(await super._call(model_inputs));
}
}

/**
* XLMRobertaForTokenClassification class for performing token classification on XLMRoberta models.
* @extends XLMRobertaPreTrainedModel
*/
export class XLMRobertaForTokenClassification extends XLMRobertaPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<TokenClassifierOutput>} An object containing the model's output logits for token classification.
*/
async _call(model_inputs) {
return new TokenClassifierOutput(await super._call(model_inputs));
}
}

/**
* XLMRobertaForQuestionAnswering class for performing question answering on XLMRoberta models.
* @extends XLMRobertaPreTrainedModel
*/
export class XLMRobertaForQuestionAnswering extends XLMRobertaPreTrainedModel {
/**
* Calls the model on new inputs.
*
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<QuestionAnsweringModelOutput>} returned object
*/
async _call(model_inputs) {
return new QuestionAnsweringModelOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// T5 models
export class WhisperPreTrainedModel extends PreTrainedModel { };
Expand Down Expand Up @@ -2494,6 +2564,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['albert', AlbertModel],
['distilbert', DistilBertModel],
['roberta', RobertaModel],
['xlm-roberta', XLMRobertaModel],
['clip', CLIPModel],
['mobilebert', MobileBertModel],
['squeezebert', SqueezeBertModel],
Expand Down Expand Up @@ -2522,6 +2593,7 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
['albert', AlbertForSequenceClassification],
['distilbert', DistilBertForSequenceClassification],
['roberta', RobertaForSequenceClassification],
['xlm-roberta', XLMRobertaForSequenceClassification],
['bart', BartForSequenceClassification],
['mobilebert', MobileBertForSequenceClassification],
['squeezebert', SqueezeBertForSequenceClassification],
Expand All @@ -2531,6 +2603,7 @@ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', BertForTokenClassification],
['distilbert', DistilBertForTokenClassification],
['roberta', RobertaForTokenClassification],
['xlm-roberta', XLMRobertaForTokenClassification],
]);

const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([
Expand All @@ -2553,6 +2626,7 @@ const MODEL_FOR_MASKED_LM_MAPPING_NAMES = new Map([
['albert', AlbertForMaskedLM],
['distilbert', DistilBertForMaskedLM],
['roberta', RobertaForMaskedLM],
['xlm-roberta', XLMRobertaForMaskedLM],
['mobilebert', MobileBertForMaskedLM],
['squeezebert', SqueezeBertForMaskedLM],
]);
Expand All @@ -2562,6 +2636,7 @@ const MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = new Map([
['albert', AlbertForQuestionAnswering],
['distilbert', DistilBertForQuestionAnswering],
['roberta', RobertaForQuestionAnswering],
['xlm-roberta', XLMRobertaForQuestionAnswering],
['mobilebert', MobileBertForQuestionAnswering],
['squeezebert', SqueezeBertForQuestionAnswering],
]);
Expand Down