PROD-108: col tokenizer
* Initial v2 SP column tokenizer

Co-authored-by: John Myers <john@gretel.ai>
GitOrigin-RevId: fddfb9b3433ddc1e03ae14e31e71fab87fc3abbb
johntmyers and John Myers committed Jun 24, 2022
1 parent 2c008c0 commit fbcbfea
Showing 4 changed files with 234 additions and 10 deletions.
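In short, rather than mapping the field delimiter to a single token, the new trainer wraps every column value with its own <colN> marker and registers those markers as SentencePiece user-defined symbols. A minimal usage sketch, pieced together from the tests added in this commit (paths and config values are illustrative, and the TensorFlowConfig import is assumed from the library's existing layout):

    from gretel_synthetics.config import TensorFlowConfig
    from gretel_synthetics.tokenizers import (
        SentencePieceColumnTokenizer,
        SentencePieceColumnTokenizerTrainer,
    )

    config = TensorFlowConfig(
        field_delimiter=",",
        checkpoint_dir="/tmp/col-tok-demo",  # illustrative path
        input_data_path="/tmp/train.csv",    # illustrative path
    )
    trainer = SentencePieceColumnTokenizerTrainer(config=config)  # col_pattern defaults to "<col{}>"
    trainer.annotate_data()  # rewrites each training line as <col0>v0<col1>v1...<n>
    trainer.train()          # trains SentencePiece with every <colN> as a user-defined symbol

    tokenizer = SentencePieceColumnTokenizer.load(config.checkpoint_dir)
    tokenizer.tokenize_delimiter("this is,a test,")                    # -> "<col0>this is<col1>a test<col2>"
    tokenizer.detokenize_delimiter("<col0>this is<col1>a test<col2>")  # -> "this is,a test,"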
134 changes: 130 additions & 4 deletions src/gretel_synthetics/tokenizers.py
@@ -28,11 +28,22 @@
"""
import json
import logging
import re
import shutil

from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, TYPE_CHECKING
from typing import (
Any,
Dict,
Iterator,
List,
Optional,
Pattern,
Set,
Tuple,
TYPE_CHECKING,
)

import cloudpickle
import gretel_synthetics.const as const
@@ -207,12 +218,14 @@ class BaseTokenizer(Base):

field_delimiter: Optional[str] = None
field_delimiter_token: Optional[str] = None
settings: Optional[dict] = None

def __init__(self, model_data: Any, model_dir: str):
self._model = model_data
self._model_dir = model_dir

self._load_delimiter_data()
self._load_settings_dict()
super().__init__()

def _load_delimiter_data(self):
@@ -229,6 +242,16 @@ def _load_delimiter_data(self):
self.field_delimiter = params_dict.get(FIELD_DELIM, None)
self.field_delimiter_token = params_dict.get(FIELD_DELIM_TOKEN, None)

def _load_settings_dict(self):
"""Load the dictionary of custom settings that were saved out
when creating the tokenizer
"""
params_file = Path(self._model_dir) / self.settings_fname
if not params_file.is_file():
self.settings = {}
return

self.settings = json.loads(params_file.read_text())

@classmethod
@abstractmethod
def load(cls, model_dir: str):
@@ -430,23 +453,32 @@ def __init__(

super().__init__(**kwargs)

def _no_delim_line(self, line: str) -> str:
return line.strip() + self.newline_str + "\n"

def _annotate_training_line(self, line: str):
if self.config.field_delimiter is not None:
line = line.strip().replace(
self.config.field_delimiter, self.config.field_delimiter_token
)
line += f"{self.newline_str}\n"
else:
line = line.strip() + self.newline_str + "\n"
line = self._no_delim_line(line)

return line

def _train(self):
def _train(self, extra_symbols: Optional[List[str]] = None):
if extra_symbols is None:
extra_symbols = []
user_defined_symbols = [
self.newline_str,
self.config.field_delimiter_token,
] + extra_symbols
logging.info("Training SentencePiece tokenizer")
spm.SentencePieceTrainer.Train(
input=self.config.training_data_path,
model_prefix=const.MODEL_PREFIX,
user_defined_symbols=[self.newline_str, self.config.field_delimiter_token],
user_defined_symbols=user_defined_symbols,
vocab_size=self.vocab_size,
hard_vocab_limit=False,
max_sentence_length=self.max_line_line,
@@ -475,6 +507,65 @@ def _get_save_settings(self):
}


_DEFAULT_COL_PATTERN = "<col{}>"
_COL_PATTERN_RE = re.compile(r"^<[a-zA-Z]{3,5}\{\}>$")
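# i.e. "<name{}>" where the name is 3-5 letters: the default "<col{}>" matches,
# while e.g. "<nope2{}>" is rejected (see the ValueError raised by
# SentencePieceColumnTokenizerTrainer and the new test below).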
COL_PATTERN = "col_pattern"


def _add_column_markers(
delim: str,
col_pattern: str,
newline_str: str,
line: str,
ignore_newline: bool = False,
) -> Tuple[str, Set[str]]:
symbols = set() # track every column token we create i.e. <colN>
line = line.strip()
parts = line.split(delim)
new_parts = []
for idx, value in enumerate(parts):
symbol = col_pattern.format(idx)
new_parts.append(symbol)
new_parts.append(value)
symbols.add(symbol)
if not ignore_newline:
new_parts.extend([newline_str, "\n"])
return "".join(new_parts), symbols


class SentencePieceColumnTokenizerTrainer(SentencePieceTokenizerTrainer):

_col_pattern: str
_col_symbols: Set[str]

def __init__(self, col_pattern: str = _DEFAULT_COL_PATTERN, **kwargs):
if not _COL_PATTERN_RE.match(col_pattern):
raise ValueError(
f"col_pattern must satisfy the following pattern: {_COL_PATTERN_RE.pattern}"
)
self._col_pattern = col_pattern
self._col_symbols = set()
super().__init__(**kwargs)

def _annotate_training_line(self, line: str) -> str:
if self.config.field_delimiter is not None:
new_line, symbols = _add_column_markers(
self.config.field_delimiter, self._col_pattern, self.newline_str, line
)
self._col_symbols = self._col_symbols.union(symbols)
return new_line
else:
return self._no_delim_line(line)

def _get_save_settings(self) -> dict:
curr_dict = super()._get_save_settings()
curr_dict[COL_PATTERN] = self._col_pattern
return curr_dict

def _train(self):
super()._train(extra_symbols=sorted(list(self._col_symbols)))
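# sorting the collected <colN> symbols keeps the user_defined_symbols order
# deterministic across training runs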


def _log_sample_data(model_dir: str, sp: spm.SentencePieceProcessor):
training_data_path = Path(model_dir) / const.TRAINING_DATA
if not training_data_path.is_file():
@@ -545,12 +636,47 @@ def detokenize_delimiter(self, line: str) -> str:
return line.replace(self.field_delimiter_token, self.field_delimiter)


class SentencePieceColumnTokenizer(SentencePieceTokenizer):
_col_pattern: str
_col_pattern_re: Pattern

def __init__(self, sp: spm.SentencePieceProcessor, model_dir: str):
super().__init__(sp, model_dir)
self._col_pattern = self.settings.get(COL_PATTERN)
self._col_pattern_re = re.compile(self._col_pattern.replace("{}", "\\d+"))

def _restore_delims(self, line_with_tokens: str) -> str:
# NOTE: Because this tokenizer places the <colN> markers _before_ the
# actual values, the first element of the split parts is always an empty
# string, so we skip it; e.g. "<col0>foo<col1>bar" splits into
# ['', 'foo', 'bar'] and we keep ['foo', 'bar'].
parts = self._col_pattern_re.split(line_with_tokens)[1:]
return self.field_delimiter.join(parts)

def _replace_decoded_tokens(self, decoded_line: str) -> str:
return self._restore_delims(decoded_line)

def tokenize_delimiter(self, line: str) -> str:
new_line, _ = _add_column_markers(
self.field_delimiter,
self._col_pattern,
self.newline_str, # ignored
line,
ignore_newline=True,
)
return new_line

def detokenize_delimiter(self, line: str) -> str:
return self._replace_decoded_tokens(line)
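# Round trip with field_delimiter="," (mirrors the new test in tests/test_tokenizers.py):
#   tokenize_delimiter("this is,a test,")                    -> "<col0>this is<col1>a test<col2>"
#   detokenize_delimiter("<col0>this is<col1>a test<col2>")  -> "this is,a test,"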


##########
# Factory
##########

TOK_MAP = {
SentencePieceTokenizerTrainer.__name__: SentencePieceTokenizer,
SentencePieceColumnTokenizerTrainer.__name__: SentencePieceColumnTokenizer,
CharTokenizerTrainer.__name__: CharTokenizer,
}
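# Sketch of the factory in use: a model directory written by the column trainer
# now resolves to the column-aware tokenizer, e.g.:
#   tokenizer = tokenizer_from_model_dir(model_dir)
#   assert isinstance(tokenizer, SentencePieceColumnTokenizer)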

6 changes: 3 additions & 3 deletions tests-integration/test_generate.py
@@ -225,11 +225,11 @@ def scooter_val(line):

SIMPLE_MODELS = [
(
"https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/scooter-simple-sp-0-14.tar.gz",
"https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/scooter-simple-sp-0-14.tar.gz", # noqa
scooter_val,
), # noqa
(
"https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/scooter-simple-char-0-15.tar.gz",
"https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/scooter-simple-char-0-15.tar.gz", # noqa
scooter_val,
), # noqa
]
Expand All @@ -251,7 +251,7 @@ def test_generate_simple(model_path, validator_fn, tmp_path):
"model_path,seed",
[
(
"https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/safecast-batch-sp-0-14.tar.gz",
"https://gretel-public-website.s3-us-west-2.amazonaws.com/tests/synthetics/models/safecast-batch-sp-0-14.tar.gz", # noqa
{"payload.service_handler": "i-051a2a353509414f0"},
) # noqa
],
12 changes: 9 additions & 3 deletions tests-integration/test_train.py
@@ -15,6 +15,7 @@
from gretel_synthetics.tokenizers import (
BaseTokenizerTrainer,
CharTokenizerTrainer,
SentencePieceColumnTokenizerTrainer,
SentencePieceTokenizerTrainer,
)
from gretel_synthetics.train import EpochState
@@ -128,20 +129,25 @@ def test_train_batch_char_tok(train_df, tmp_path):
assert syn_df.shape[0] == _tok_gen_count


def test_train_batch_sp_tok(train_df, tmp_path):
@pytest.mark.parametrize(
"tok_class", [SentencePieceTokenizerTrainer, SentencePieceColumnTokenizerTrainer]
)
def test_train_batch_sp_tok(train_df, tmp_path, tok_class):
config = TensorFlowConfig(
epochs=5,
field_delimiter=",",
checkpoint_dir=tmp_path,
input_data_path=PATH_HOLDER,
learning_rate=0.01,
)
tokenizer = SentencePieceTokenizerTrainer(vocab_size=10000, config=config)
tokenizer = tok_class(vocab_size=10000, config=config)
batcher = DataFrameBatch(df=train_df, config=config, tokenizer=tokenizer)
batcher.create_training_data()
batcher.train_all_batches()

batcher.generate_all_batch_lines(num_lines=_tok_gen_count, max_invalid=5000)
batcher.generate_all_batch_lines(
num_lines=_tok_gen_count, max_invalid=5000, parallelism=1
)
syn_df = batcher.batches_to_df()
assert syn_df.shape[0] == _tok_gen_count

92 changes: 92 additions & 0 deletions tests/test_tokenizers.py
@@ -212,6 +212,98 @@ def test_sp(input_data_path, tmpdir):
)


def test_sp_column_tok(input_data_path, tmpdir):
# We can only use valid token patterns
with pytest.raises(ValueError):
tok.SentencePieceColumnTokenizerTrainer(col_pattern="<nope2{}>")

config = SimpleConfig(
input_data_path=input_data_path, checkpoint_dir=tmpdir, field_delimiter=","
)
trainer = tok.SentencePieceColumnTokenizerTrainer(config=config)
line_iter = trainer.annotate_data()
line_one = next(line_iter)
assert (
line_one
== "<col0>Once upon a midnight dreary<col1> while I pondered<col2> weak and weary<col3><n>\n"
)

trainer.train()
assert len(trainer._col_symbols) == 4
tokenizer = tok.SentencePieceColumnTokenizer.load(tmpdir)

# Validate that our column pattern was saved out and restored
assert tokenizer._col_pattern == tok._DEFAULT_COL_PATTERN

ids = [
9,
5,
43,
57,
11,
9,
14,
38,
13,
17,
19,
16,
20,
19,
25,
23,
18,
9,
16,
29,
34,
10,
6,
55,
44,
12,
11,
9,
26,
9,
38,
16,
50,
16,
7,
52,
65,
13,
31,
52,
29,
10,
8,
3,
]
assert (
tokenizer.encode_to_ids(
"<col0>Once upon a midnight dreary<col1> while I pondered<col2> weak and weary<col3><n>\n"
)
== ids
)

assert (
tokenizer.decode_from_ids(ids)
== "Once upon a midnight dreary, while I pondered, weak and weary,<n>"
)
assert isinstance(
tok.tokenizer_from_model_dir(tmpdir), tok.SentencePieceColumnTokenizer
)

substring = "this is,a test,"
check = tokenizer.tokenize_delimiter(substring)
assert check == "<col0>this is<col1>a test<col2>"

check2 = tokenizer.detokenize_delimiter(check)
assert check2 == "this is,a test,"


def test_sp_field_delim(input_data_path, tmpdir):
config = SimpleConfig(
input_data_path=input_data_path, checkpoint_dir=tmpdir, field_delimiter=","
