PLAT-820: Make "vocab_size too small" a handled exceptions.
GitOrigin-RevId: 95f6de769584efdd01c005cc69a36aade4a289a8
pimlock committed Feb 15, 2024
1 parent a977311 commit a63c68d
Showing 2 changed files with 52 additions and 12 deletions.
src/gretel_synthetics/tokenizers.py: 32 additions, 12 deletions
@@ -50,6 +50,7 @@
 import numpy as np
 import sentencepiece as spm
 
+from gretel_synthetics.errors import ParameterError
 from smart_open import open as smart_open
 
 if TYPE_CHECKING:
@@ -404,6 +405,16 @@ def _decode_from_ids(self, ids: List[int]) -> str:
 #################
 
 
+class VocabSizeTooSmall(ParameterError):
+    """
+    Error that is raised when the `vocab_size` is too small for the given data.
+    This happens when the `vocab_size` is set to a value that is smaller than the
+    number of required characters.
+    """
+
+    ...
+
+
 class SentencePieceTokenizerTrainer(BaseTokenizerTrainer):
     """Train a tokenizer using Google SentencePiece."""
 
@@ -476,17 +487,27 @@ def _train(self, extra_symbols: Optional[List[str]] = None):
             self.config.field_delimiter_token,
         ] + extra_symbols
         logger.info("Training SentencePiece tokenizer")
-        spm.SentencePieceTrainer.Train(
-            input=self.config.training_data_path,
-            model_prefix=const.MODEL_PREFIX,
-            user_defined_symbols=user_defined_symbols,
-            vocab_size=self.vocab_size,
-            hard_vocab_limit=False,
-            max_sentence_length=self.max_line_line,
-            input_sentence_size=self.pretrain_sentence_count,
-            shuffle_input_sentence=True,
-            character_coverage=self.character_coverage,
-        )
+        try:
+            spm.SentencePieceTrainer.Train(
+                input=self.config.training_data_path,
+                model_prefix=const.MODEL_PREFIX,
+                user_defined_symbols=user_defined_symbols,
+                vocab_size=self.vocab_size,
+                hard_vocab_limit=False,
+                max_sentence_length=self.max_line_line,
+                input_sentence_size=self.pretrain_sentence_count,
+                shuffle_input_sentence=True,
+                character_coverage=self.character_coverage,
+            )
+        except RuntimeError as e:
+            if "Vocabulary size is smaller than required_chars" in str(e):
+                raise VocabSizeTooSmall(
+                    "The value for the `vocab_size` parameter is too small for the "
+                    "provided dataset. Please increase it or set it to `0` to use "
+                    "character-based tokenization."
+                ) from e
+
+            raise e
 
         # The training automatically saves to disk,
         # so we have to now load it back in after we move
@@ -535,7 +556,6 @@ def _add_column_markers(
 
 
 class SentencePieceColumnTokenizerTrainer(SentencePieceTokenizerTrainer):
-
     _col_pattern: str
     _col_symbols: Set[str]
 
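For reference (not part of the commit), this is roughly how downstream code could now handle the new exception. It is a minimal sketch: `my_config` and the `vocab_size` value are placeholders, assuming any already-constructed config object accepted by the trainer.

# Illustrative sketch only; `my_config` is an assumed, caller-provided config.
from gretel_synthetics.tokenizers import SentencePieceTokenizerTrainer, VocabSizeTooSmall

trainer = SentencePieceTokenizerTrainer(config=my_config, vocab_size=5000)
trainer.annotate_data()  # prepares the annotated training data SentencePiece consumes

try:
    trainer.train()
except VocabSizeTooSmall:
    # Before this commit only SentencePiece's raw RuntimeError surfaced here;
    # now the condition can be handled explicitly, e.g. by prompting the user
    # to increase vocab_size or to set it to 0 for character-based tokenization.
    raise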
tests/test_tokenizers.py: 20 additions, 0 deletions
@@ -5,6 +5,7 @@
 import pytest
 
 from gretel_synthetics.config import BaseConfig
+from gretel_synthetics.tokenizers import VocabSizeTooSmall
 
 
 class SimpleConfig(BaseConfig):
@@ -377,3 +378,22 @@ def test_sp_field_delim(input_data_path, tmpdir):
 
     # Check the factory
     assert isinstance(tok.tokenizer_from_model_dir(tmpdir), tok.SentencePieceTokenizer)
+
+
+def test_vocab_size_too_small(input_data_path, tmpdir):
+    config = SimpleConfig(
+        input_data_path=input_data_path,
+        checkpoint_dir=tmpdir,
+        field_delimiter=",",
+    )
+    trainer = tok.SentencePieceTokenizerTrainer(config=config, vocab_size=5)
+    line_iter = trainer.annotate_data()
+
+    line_one = next(line_iter)
+    assert (
+        line_one
+        == "Once upon a midnight dreary<d> while I pondered<d> weak and weary<d><n>\n"
+    )
+
+    with pytest.raises(VocabSizeTooSmall):
+        trainer.train()
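Because VocabSizeTooSmall derives from ParameterError (imported from gretel_synthetics.errors in the tokenizers.py hunk above), code that already guards against parameter problems picks up the new case without changes. A minimal sketch, with the trainer assumed to be constructed as in the test above:

from gretel_synthetics.errors import ParameterError

def try_train(trainer) -> bool:
    """Attempt tokenizer training; report parameter problems instead of crashing."""
    try:
        trainer.train()
    except ParameterError as err:
        # VocabSizeTooSmall is a ParameterError, so a broad handler like this
        # also covers the new "vocab_size too small" case.
        print(f"Tokenizer training aborted due to a parameter problem: {err}")
        return False
    return True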