Commit 68cc8b5

code review feedback: tokenization utils public and abstract tokenization normalization

jeswan committed Mar 19, 2021
1 parent 391bd59 commit 68cc8b5
Showing 3 changed files with 22 additions and 17 deletions.
19 changes: 12 additions & 7 deletions jiant/proj/main/modeling/primary.py
@@ -21,9 +21,9 @@

from jiant.utils.tokenization_utils import bow_tag_tokens
from jiant.utils.tokenization_utils import eow_tag_tokens
-from jiant.utils.tokenization_utils import _process_bytebpe_tokens
-from jiant.utils.tokenization_utils import _process_wordpiece_tokens
-from jiant.utils.tokenization_utils import _process_sentencepiece_tokens
+from jiant.utils.tokenization_utils import process_bytebpe_tokens
+from jiant.utils.tokenization_utils import process_wordpiece_tokens
+from jiant.utils.tokenization_utils import process_sentencepiece_tokens


@dataclass
@@ -175,6 +175,11 @@ def __init__(self, baseObject):
)
self.__dict__ = baseObject.__dict__

+@classmethod
+@abc.abstractmethod
+def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
+    pass
+
@abc.abstractmethod
def get_mlm_weights_dict(self, weights_dict):
pass
@@ -214,7 +219,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat
if tokenizer.init_kwargs.get("do_lower_case", False):
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
-modifed_target_tokenization = _process_wordpiece_tokens(target_tokenization)
+modifed_target_tokenization = process_wordpiece_tokens(target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

@@ -256,7 +261,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat
"""See tokenization_normalization.py for details"""
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:]
-modifed_target_tokenization = _process_bytebpe_tokens(modifed_target_tokenization)
+modifed_target_tokenization = process_bytebpe_tokens(modifed_target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

@@ -296,7 +301,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat
"""See tokenization_normalization.py for details"""
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
-modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization)
+modifed_target_tokenization = process_sentencepiece_tokens(target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

@@ -366,7 +371,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat
"""See tokenization_normalization.py for details"""
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
-modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization)
+modifed_target_tokenization = process_sentencepiece_tokens(target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

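For orientation, the primary.py hunks above do two things: the underscore-prefixed helpers are imported under their new public names, and normalize_tokenizations is declared as an abstract classmethod on the base model wrapper so that every concrete model class must provide its own normalization. The sketch below is illustrative only and not part of the commit: ExampleBaseModel and ExampleWordPieceModel are hypothetical names, and the WordPiece branch simply paraphrases the BERT-style hunk shown above.

```python
import abc


class ExampleBaseModel(abc.ABC):
    """Hypothetical stand-in for the jiant base model wrapper (not jiant API)."""

    @classmethod
    @abc.abstractmethod
    def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
        """Return both tokenizations in a comparable, word-boundary-tagged form."""


class ExampleWordPieceModel(ExampleBaseModel):
    """Hypothetical subclass mirroring the WordPiece hunk in the diff above."""

    @classmethod
    def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
        # Lower-case the space tokenization if the tokenizer does.
        if tokenizer.init_kwargs.get("do_lower_case", False):
            space_tokenization = [token.lower() for token in space_tokenization]
        # Tag word beginnings on the space side; strip the '##' continuation
        # prefix (or tag word beginnings) on the WordPiece side.
        tagged_space = ["<w>" + t for t in space_tokenization]
        tagged_target = [t[2:] if t.startswith("##") else "<w>" + t for t in target_tokenization]
        return tagged_space, tagged_target
```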
18 changes: 9 additions & 9 deletions jiant/utils/tokenization_utils.py
@@ -13,35 +13,35 @@ def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = "</w>"):
return [t + eow_tag for t in tokens]


-def _process_wordpiece_tokens(tokens: Sequence[str]):
-return [_process_wordpiece_token_for_alignment(token) for token in tokens]
+def process_wordpiece_tokens(tokens: Sequence[str]):
+return [process_wordpiece_token_for_alignment(token) for token in tokens]


-def _process_sentencepiece_tokens(tokens: Sequence[str]):
-return [_process_sentencepiece_token_for_alignment(token) for token in tokens]
+def process_sentencepiece_tokens(tokens: Sequence[str]):
+return [process_sentencepiece_token_for_alignment(token) for token in tokens]


-def _process_bytebpe_tokens(tokens: Sequence[str]):
-return [_process_bytebpe_token_for_alignment(token) for token in tokens]
+def process_bytebpe_tokens(tokens: Sequence[str]):
+return [process_bytebpe_token_for_alignment(token) for token in tokens]


-def _process_wordpiece_token_for_alignment(t):
+def process_wordpiece_token_for_alignment(t):
"""Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT)."""
if t.startswith("##"):
return re.sub(r"^##", "", t)
else:
return "<w>" + t


-def _process_sentencepiece_token_for_alignment(t):
+def process_sentencepiece_token_for_alignment(t):
"""Add word boundary markers, removes token prefix (space meta-symbol)."""
if t.startswith("▁"):
return "<w>" + re.sub(r"^▁", "", t)
else:
return t


-def _process_bytebpe_token_for_alignment(t):
+def process_bytebpe_token_for_alignment(t):
"""Add word boundary markers, removes token prefix (space meta-symbol)."""
if t.startswith("Ġ"):
return "<w>" + re.sub(r"^Ġ", "", t)
2 changes: 1 addition & 1 deletion tests/utils/test_tokenization_normalization.py
@@ -145,7 +145,7 @@ def test_process_bytebpe_token_sequence():
"Ġrules",
".",
]
-adjusted_bytebpe_tokens = tu._process_bytebpe_tokens(original_bytebpe_tokens)
+adjusted_bytebpe_tokens = tu.process_bytebpe_tokens(original_bytebpe_tokens)
assert adjusted_bytebpe_tokens == expected_adjusted_bytebpe_tokens


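To make the effect of the rename concrete, here is a small usage sketch (not part of the commit). It assumes bow_tag_tokens prepends its default "<w>" tag to each token, mirroring the eow_tag_tokens helper shown above; the example tokens are invented.

```python
from jiant.utils.tokenization_utils import bow_tag_tokens
from jiant.utils.tokenization_utils import process_wordpiece_tokens

# Invented example: a whitespace tokenization and the corresponding
# BERT WordPiece tokenization of the same text.
space_tokens = ["jiant", "supports", "alignment"]
wordpiece_tokens = ["ji", "##ant", "supports", "alignment"]

# Tag word beginnings in the space tokenization (assumes the default '<w>' tag).
tagged_space = bow_tag_tokens(space_tokens)

# Strip the '##' continuation prefix and tag word beginnings on the WordPiece side.
tagged_target = process_wordpiece_tokens(wordpiece_tokens)

print(tagged_space)   # expected (assumption): ['<w>jiant', '<w>supports', '<w>alignment']
print(tagged_target)  # ['<w>ji', 'ant', '<w>supports', '<w>alignment']
```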
