diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
index d679ab9ba3..241ec4b5b6 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer.py
@@ -123,10 +123,18 @@ def id_to_token(self, id):
 
     def token_to_id(self, token):
         """Convert a string token to an integer id."""
+
         if token in self._vocabulary_prefix:
             return self._vocabulary_prefix.index(token)
-        return int(self._sentence_piece.string_to_id(token).numpy()) + 1
+        spm_token_id = self._sentence_piece.string_to_id(token)
+
+        # SentencePiece maps any out-of-vocabulary token to the id of
+        # "<unk>"; detect that case and return the model-level unk id
+        # instead of shifting the SentencePiece id by the prefix offset.
+        spm_unk_token_id = self._sentence_piece.string_to_id("<unk>")
+        if spm_token_id == spm_unk_token_id:
+            return self.unk_token_id
+
+        return int(spm_token_id.numpy()) + 1
 
     def tokenize(self, inputs):
         tokens = super().tokenize(inputs)
diff --git a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
index 912403fc4e..8ad1fddbfb 100644
--- a/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
+++ b/keras_nlp/models/xlm_roberta/xlm_roberta_tokenizer_test.py
@@ -96,6 +96,10 @@ def test_id_to_token(self):
     def test_token_to_id(self):
         self.assertEqual(self.tokenizer.token_to_id("▁the"), 4)
         self.assertEqual(self.tokenizer.token_to_id("▁round"), 10)
+        # Test any random OOV token.
+        self.assertEqual(self.tokenizer.token_to_id("<oov-token>"), 3)
+        # Test a special token.
+        self.assertEqual(self.tokenizer.token_to_id("<pad>"), 1)
 
     def test_serialization(self):
         config = keras.utils.serialize_keras_object(self.tokenizer)