Sp updates (#23)
* Upgrade SentencePiece and enable "max_sentence_length" option

* version bump

Co-authored-by: John Myers <john@gretel.ai>
johntmyers and John Myers committed Jun 3, 2020
1 parent 5edeff8 commit 508276d
Showing 6 changed files with 17 additions and 3 deletions.
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-0.9.2
+0.9.3
3 changes: 2 additions & 1 deletion examples/synthetic_records.ipynb
@@ -45,6 +45,7 @@
" max_lines=0, # use max_lines of training data. Set to 0 (zero) to on all lines in dataset\n",
" epochs=15, # 15-50 epochs with GPU for best performance\n",
" vocab_size=15000, # tokenizer model vocabulary size\n",
" max_line_len=2048, # the max line length SentencePiece will consider\n",
" character_coverage=1.0, # tokenizer model character coverage percent\n",
" gen_chars=0, # the maximum number of characters possible per-generated line of text\n",
" gen_lines=100, # the number of generated text lines\n",
@@ -61,7 +62,7 @@
" field_delimiter=\",\", # if the training text is structured\n",
" # overwrite=True, # enable this if you want to keep training models to the same checkpoint location\n",
" input_data_path=\"https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uber_scooter_rides_1day.csv\" # filepath or S3\n",
")"
")\n"
]
},
{
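The diff above shows only part of the notebook cell. For reference, a minimal self-contained version of the updated configuration might look like the sketch below; it assumes the cell constructs a LocalConfig (the class exercised in tests/test_config.py further down) and uses an illustrative checkpoint_dir.

    from gretel_synthetics.config import LocalConfig

    config = LocalConfig(
        max_lines=0,              # max lines of training data to use; 0 (zero) trains on all lines
        epochs=15,                # 15-50 epochs with GPU for best performance
        vocab_size=15000,         # tokenizer model vocabulary size
        max_line_len=2048,        # new in this commit: the max line length SentencePiece will consider
        character_coverage=1.0,   # tokenizer model character coverage percent
        gen_chars=0,              # the maximum number of characters possible per generated line of text
        gen_lines=100,            # the number of generated text lines
        field_delimiter=",",      # if the training text is structured
        checkpoint_dir="./checkpoints",  # illustrative path; the real notebook sets its own
        input_data_path="https://gretel-public-website.s3-us-west-2.amazonaws.com/datasets/uber_scooter_rides_1day.csv",
    )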
2 changes: 1 addition & 1 deletion setup.py
@@ -20,7 +20,7 @@
packages=find_packages('src'),
install_requires=[
'tensorflow_privacy==0.2.2',
-'sentencepiece==0.1.85',
+'sentencepiece==0.1.91',
'smart_open==1.10.0',
'tqdm<5.0',
'pandas==1.0.3',
1 change: 1 addition & 0 deletions src/gretel_synthetics/config.py
@@ -39,6 +39,7 @@ class _BaseConfig:
rnn_units: int = 256
dropout_rate: float = 0.2
rnn_initializer: str = "glorot_uniform"
+max_line_len: int = 2048

# Input data configs
field_delimiter: Optional[str] = None
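Because the config fields are annotated class attributes with defaults (a dataclass-style definition), the new setting becomes an ordinary keyword argument. A small sketch of how it might surface, assuming LocalConfig as used in the example notebook above (paths are illustrative):

    from gretel_synthetics.config import LocalConfig

    config = LocalConfig(checkpoint_dir="./checkpoints", input_data_path="data.csv")
    print(config.max_line_len)   # 2048, the new default

    # override the default when training lines are known to be shorter or longer
    config = LocalConfig(
        max_line_len=1024,
        checkpoint_dir="./checkpoints",
        input_data_path="data.csv",
    )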
11 changes: 11 additions & 0 deletions src/gretel_synthetics/train.py
@@ -158,13 +158,24 @@ def _train_tokenizer(store: _BaseConfig) -> spm.SentencePieceProcessor:
    Trains SentencePiece tokenizer on training data
    """
    logging.info("Training SentencePiece tokenizer")
+    spm.SentencePieceTrainer.Train(
+        input=store.training_data,
+        model_prefix=store.tokenizer_prefix,
+        user_defined_symbols=["<n>", store.field_delimiter_token],
+        vocab_size=store.vocab_size,
+        hard_vocab_limit=False,
+        max_sentence_length=store.max_line_len,
+        character_coverage=store.character_coverage
+    )
+    """
    spm.SentencePieceTrainer.Train(
        f'--input={store.training_data} '
        f'--model_prefix={store.tokenizer_prefix} '
        f'--user_defined_symbols=<n>,{store.field_delimiter_token} '
        f'--vocab_size={store.vocab_size} '
        f'--hard_vocab_limit=false '
        f'--character_coverage={store.character_coverage}')
+    """
    _move_tokenizer_model(store)

    sp = spm.SentencePieceProcessor()
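The new call uses the native keyword-argument form of SentencePieceTrainer.Train that sentencepiece 0.1.91 supports, replacing the flag-string form left commented out above, and forwards the config's max_line_len as max_sentence_length. A standalone sketch of the same pattern outside the library, with illustrative paths, a toy vocabulary size, and "<d>" standing in for the field delimiter token:

    import sentencepiece as spm

    # train a tokenizer model; writes m.model and m.vocab in the working directory
    spm.SentencePieceTrainer.Train(
        input="training_data.txt",        # newline-delimited training text
        model_prefix="m",
        user_defined_symbols=["<n>", "<d>"],
        vocab_size=8000,
        hard_vocab_limit=False,
        max_sentence_length=2048,         # the longest input line the trainer will consider
        character_coverage=1.0,
    )

    # load the trained model and tokenize a line
    sp = spm.SentencePieceProcessor()
    sp.Load("m.model")
    print(sp.EncodeAsPieces("hello, world<n>"))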
1 change: 1 addition & 0 deletions tests/test_config.py
@@ -48,6 +48,7 @@ def test_local_config_settings(mkdir):
"gen_temp": 1.0,
"gen_chars": 0,
"gen_lines": 500,
"max_line_len": 2048,
"save_all_checkpoints": True,
"checkpoint_dir": "foo",
"field_delimiter": None,
