From fb40cf151c92a0e49a98075da19001a4017275df Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 29 Jun 2023 12:14:37 +0200 Subject: [PATCH] Add `RobertaForTokenClassification` and an example checkpoint on Hub --- scripts/supported_models.py | 1 + src/models.js | 17 +++++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/scripts/supported_models.py b/scripts/supported_models.py index e42613095..673af7cad 100644 --- a/scripts/supported_models.py +++ b/scripts/supported_models.py @@ -153,6 +153,7 @@ 'sentence-transformers/all-distilroberta-v1', 'sentence-transformers/all-roberta-large-v1', + 'julien-c/EsperBERTo-small-pos', ], 'sam': [ 'facebook/sam-vit-base', diff --git a/src/models.js b/src/models.js index 94c5b4bdb..3ece97429 100644 --- a/src/models.js +++ b/src/models.js @@ -1675,6 +1675,22 @@ export class RobertaForSequenceClassification extends RobertaPreTrainedModel { } } +/** + * RobertaForTokenClassification class for performing token classification on Roberta models. + * @extends RobertaPreTrainedModel + */ +export class RobertaForTokenClassification extends RobertaPreTrainedModel { + /** + * Calls the model on new inputs. + * + * @param {Object} model_inputs The inputs to the model. + * @returns {Promise} An object containing the model's output logits for token classification. + */ + async _call(model_inputs) { + return new TokenClassifierOutput(await super._call(model_inputs)); + } +} + /** * RobertaForQuestionAnswering class for performing question answering on Roberta models. * @extends RobertaPreTrainedModel @@ -2514,6 +2530,7 @@ const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([ const MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = new Map([ ['bert', BertForTokenClassification], ['distilbert', DistilBertForTokenClassification], + ['roberta', RobertaForTokenClassification], ]); const MODEL_FOR_SEQ_2_SEQ_MAPPING_NAMES = new Map([