From 21e773600f26e4c343373140ac58e3316cd1cce3 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Thu, 25 Feb 2021 14:00:04 -0500 Subject: [PATCH 1/4] move model specific tokenization logic to JiantTransformerModels --- jiant/proj/main/modeling/primary.py | 78 +++++++++++++--- jiant/shared/model_resolution.py | 38 +++++--- jiant/utils/python/datastructures.py | 41 +++++++++ jiant/utils/tokenization_normalization.py | 89 ++++--------------- jiant/utils/tokenization_utils.py | 49 ++++++++++ .../utils/test_tokenization_normalization.py | 7 +- 6 files changed, 205 insertions(+), 97 deletions(-) create mode 100644 jiant/utils/tokenization_utils.py diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py index da504ab80..d6f1bd9c0 100644 --- a/jiant/proj/main/modeling/primary.py +++ b/jiant/proj/main/modeling/primary.py @@ -1,6 +1,7 @@ import abc from dataclasses import dataclass + from typing import Any from typing import Callable from typing import Dict @@ -9,14 +10,20 @@ import torch import torch.nn as nn -import jiant.tasks as tasks import jiant.utils.python.strings as strings +from jiant.tasks.core import BatchMixin +from jiant.tasks.core import FeaturizationSpec +from jiant.tasks.core import Task +from jiant.proj.main.components.outputs import construct_output_from_dict from jiant.proj.main.modeling.taskmodels import Taskmodel from jiant.shared.model_resolution import ModelArchitectures -from jiant.tasks.core import FeaturizationSpec -from jiant.proj.main.components.outputs import construct_output_from_dict +from jiant.utils.tokenization_utils import bow_tag_tokens +from jiant.utils.tokenization_utils import eow_tag_tokens +from jiant.utils.tokenization_utils import _process_bytebpe_tokens +from jiant.utils.tokenization_utils import _process_wordpiece_tokens +from jiant.utils.tokenization_utils import _process_sentencepiece_tokens @dataclass @@ -29,7 +36,7 @@ class JiantModelOutput: class JiantModel(nn.Module): def __init__( self, - task_dict: Dict[str, tasks.Task], + task_dict: Dict[str, Task], encoder: nn.Module, taskmodels_dict: Dict[str, Taskmodel], task_to_taskmodel_map: Dict[str, str], @@ -42,15 +49,15 @@ def __init__( self.task_to_taskmodel_map = task_to_taskmodel_map self.tokenizer = tokenizer - def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool = False): + def forward(self, batch: BatchMixin, task: Task, compute_loss: bool = False): """Calls to this forward method are delegated to the forward of the appropriate taskmodel. When JiantModel forward is called, the task name from the task argument is used as a key to select the appropriate submodule/taskmodel, and that taskmodel's forward is called. Args: - batch (tasks.BatchMixin): model input. - task (tasks.Task): task to which to delegate the forward call. + batch (BatchMixin): model input. + task (Task): task to which to delegate the forward call. compute_loss (bool): whether to calculate and return the loss. Returns: @@ -74,20 +81,20 @@ def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool def wrap_jiant_forward( jiant_model: Union[JiantModel, nn.DataParallel], - batch: tasks.BatchMixin, - task: tasks.Task, + batch: BatchMixin, + task: Task, compute_loss: bool = False, ): """Wrapper to repackage model inputs using dictionaries for compatibility with DataParallel. 
- Wrapper that converts batches (type tasks.BatchMixin) to dictionaries before delegating to + Wrapper that converts batches (type BatchMixin) to dictionaries before delegating to JiantModel's forward method, and then converts the resulting model output dict into the appropriate model output dataclass. Args: jiant_model (Union[JiantModel, nn.DataParallel]): - batch (tasks.BatchMixin): model input batch. - task (tasks.Task): Task object passed for access in the taskmodel. + batch (BatchMixin): model input batch. + task (Task): Task object passed for access in the taskmodel. compute_loss (bool): True if loss should be computed, False otherwise. Returns: @@ -201,6 +208,16 @@ class JiantBertModel(JiantTransformersModel): def __init__(self, baseObject): super().__init__(baseObject) + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + if tokenizer.init_kwargs.get("do_lower_case", False): + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = _process_wordpiece_tokens(target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + def get_feat_spec(self, max_seq_length): return FeaturizationSpec( max_seq_length=max_seq_length, @@ -234,6 +251,15 @@ class JiantRobertaModel(JiantTransformersModel): def __init__(self, baseObject): super().__init__(baseObject) + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:] + modifed_target_tokenization = _process_bytebpe_tokens(modifed_target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + def get_mlm_weights_dict(self, weights_dict): mlm_weights_dict = { strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items() @@ -265,6 +291,15 @@ class JiantXLMRobertaModel(JiantTransformersModel): def __init__(self, baseObject): super().__init__(baseObject) + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + def get_feat_spec(self, max_seq_length): # XLM-RoBERTa is weird # token 0 = '' which is the cls_token @@ -296,6 +331,16 @@ class JiantXLMModel(JiantTransformersModel): def __init__(self, baseObject): super().__init__(baseObject) + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + if tokenizer.init_kwargs.get("do_lowercase_and_remove_accent", False): + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = eow_tag_tokens(space_tokenization) + modifed_target_tokenization = target_tokenization + + return modifed_space_tokenization, modifed_target_tokenization + def get_feat_spec(self, max_seq_length): return FeaturizationSpec( max_seq_length=max_seq_length, @@ -316,6 +361,15 @@ class 
JiantAlbertModel(JiantTransformersModel): def __init__(self, baseObject): super().__init__(baseObject) + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + """See tokenization_normalization.py for details""" + space_tokenization = [token.lower() for token in space_tokenization] + modifed_space_tokenization = bow_tag_tokens(space_tokenization) + modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization) + + return modifed_space_tokenization, modifed_target_tokenization + def get_mlm_weights_dict(self, weights_dict): mlm_weights_dict = { strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items() diff --git a/jiant/shared/model_resolution.py b/jiant/shared/model_resolution.py index 21067bd07..b4eab62fe 100644 --- a/jiant/shared/model_resolution.py +++ b/jiant/shared/model_resolution.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from enum import Enum +from jiant.utils.python.datastructures import BiDict import transformers @@ -26,16 +27,18 @@ class ModelClassSpec: model_class: type -TOKENIZER_CLASS_DICT = { - ModelArchitectures.BERT: transformers.BertTokenizer, - ModelArchitectures.XLM: transformers.XLMTokenizer, - ModelArchitectures.ROBERTA: transformers.RobertaTokenizer, - ModelArchitectures.XLM_ROBERTA: transformers.XLMRobertaTokenizer, - ModelArchitectures.ALBERT: transformers.AlbertTokenizer, - ModelArchitectures.BART: transformers.BartTokenizer, - ModelArchitectures.MBART: transformers.MBartTokenizer, - ModelArchitectures.ELECTRA: transformers.ElectraTokenizer, -} +TOKENIZER_CLASS_DICT = BiDict( + { + ModelArchitectures.BERT: transformers.BertTokenizer, + ModelArchitectures.XLM: transformers.XLMTokenizer, + ModelArchitectures.ROBERTA: transformers.RobertaTokenizer, + ModelArchitectures.XLM_ROBERTA: transformers.XLMRobertaTokenizer, + ModelArchitectures.ALBERT: transformers.AlbertTokenizer, + ModelArchitectures.BART: transformers.BartTokenizer, + ModelArchitectures.MBART: transformers.MBartTokenizer, + ModelArchitectures.ELECTRA: transformers.ElectraTokenizer, + } +) def resolve_tokenizer_class(model_type): @@ -51,6 +54,21 @@ def resolve_tokenizer_class(model_type): return TOKENIZER_CLASS_DICT[ModelArchitectures(model_type)] +def resolve_model_arch_tokenizer(tokenizer): + """Get the model architecture for a given tokenizer. 
+ + Args: + tokenizer + + Returns: + ModelArchitecture + + """ + print(TOKENIZER_CLASS_DICT.inverse) + assert len(TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__]) == 1 + return TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__][0] + + def resolve_is_lower_case(tokenizer): if isinstance(tokenizer, transformers.BertTokenizer): return tokenizer.basic_tokenizer.do_lower_case diff --git a/jiant/utils/python/datastructures.py b/jiant/utils/python/datastructures.py index 8707261ab..eefc6e902 100644 --- a/jiant/utils/python/datastructures.py +++ b/jiant/utils/python/datastructures.py @@ -279,3 +279,44 @@ def get_maps(self) -> Tuple[Dict, Dict]: """ return self.a_to_b, self.b_to_a + + +class BiDict(dict): + """Maintains bidirectional dict + + Example: + bd = BiDict({'a': 1, 'b': 2}) + print(bd) # {'a': 1, 'b': 2} + print(bd.inverse) # {1: ['a'], 2: ['b']} + bd['c'] = 1 # Now two keys have the same value (= 1) + print(bd) # {'a': 1, 'c': 1, 'b': 2} + print(bd.inverse) # {1: ['a', 'c'], 2: ['b']} + del bd['c'] + print(bd) # {'a': 1, 'b': 2} + print(bd.inverse) # {1: ['a'], 2: ['b']} + del bd['a'] + print(bd) # {'b': 2} + print(bd.inverse) # {2: ['b']} + bd['b'] = 3 + print(bd) # {'b': 3} + print(bd.inverse) # {2: [], 3: ['b']} + + """ + + def __init__(self, *args, **kwargs): + super(BiDict, self).__init__(*args, **kwargs) + self.inverse = {} + for key, value in self.items(): + self.inverse.setdefault(value, []).append(key) + + def __setitem__(self, key, value): + if key in self: + self.inverse[self[key]].remove(key) + super(BiDict, self).__setitem__(key, value) + self.inverse.setdefault(value, []).append(key) + + def __delitem__(self, key): + self.inverse.setdefault(self[key], []).remove(key) + if self[key] in self.inverse and not self.inverse[self[key]]: + del self.inverse[self[key]] + super(BiDict, self).__delitem__(key) diff --git a/jiant/utils/tokenization_normalization.py b/jiant/utils/tokenization_normalization.py index 8a07be2c9..95c59ac5e 100644 --- a/jiant/utils/tokenization_normalization.py +++ b/jiant/utils/tokenization_normalization.py @@ -8,11 +8,12 @@ """ -import re import transformers from typing import Sequence from jiant.utils.testing import utils as test_utils +from jiant.shared.model_resolution import resolve_model_arch_tokenizer +from jiant.proj.main.modeling.primary import JiantTransformersModelFactory def normalize_tokenizations( @@ -49,80 +50,24 @@ def normalize_tokenizations( if len(space_tokenization) == 0 or len(target_tokenization) == 0: raise ValueError("Empty token sequence.") - if isinstance(tokenizer, transformers.BertTokenizer): - if tokenizer.init_kwargs.get("do_lower_case", False): - space_tokenization = [token.lower() for token in space_tokenization] - modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = _process_wordpiece_tokens(target_tokenization) - elif isinstance(tokenizer, transformers.XLMTokenizer): - if tokenizer.init_kwargs.get("do_lowercase_and_remove_accent", False): - space_tokenization = [token.lower() for token in space_tokenization] - modifed_space_tokenization = eow_tag_tokens(space_tokenization) - modifed_target_tokenization = target_tokenization - elif isinstance(tokenizer, transformers.RobertaTokenizer): - modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:] - modifed_target_tokenization = _process_bytebpe_tokens(modifed_target_tokenization) - elif isinstance(tokenizer, (transformers.AlbertTokenizer, 
transformers.XLMRobertaTokenizer)): - space_tokenization = [token.lower() for token in space_tokenization] - modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization) - else: - if test_utils.is_pytest(): - from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer - - if isinstance(tokenizer, SimpleSpaceTokenizer): - return space_tokenization, target_tokenization - raise ValueError("Tokenizer not supported.") + if test_utils.is_pytest(): + from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer + + if isinstance(tokenizer, SimpleSpaceTokenizer): + return space_tokenization, target_tokenization + + model_arch = resolve_model_arch_tokenizer(tokenizer) + print(model_arch) + jiant_transformer_model_class = JiantTransformersModelFactory.get_registry()[model_arch] + ( + modifed_space_tokenization, + modifed_target_tokenization, + ) = jiant_transformer_model_class.normalize_tokenizations( + tokenizer, space_tokenization, target_tokenization + ) # safety check: if normalization changed sequence length, alignment is likely to break. assert len(modifed_space_tokenization) == len(space_tokenization) assert len(modifed_target_tokenization) == len(target_tokenization) return modifed_space_tokenization, modifed_target_tokenization - - -def bow_tag_tokens(tokens: Sequence[str], bow_tag: str = ""): - """Applies a beginning of word (BoW) marker to every token in the tokens sequence.""" - return [bow_tag + t for t in tokens] - - -def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = ""): - """Applies a end of word (EoW) marker to every token in the tokens sequence.""" - return [t + eow_tag for t in tokens] - - -def _process_wordpiece_tokens(tokens: Sequence[str]): - return [_process_wordpiece_token_for_alignment(token) for token in tokens] - - -def _process_sentencepiece_tokens(tokens: Sequence[str]): - return [_process_sentencepiece_token_for_alignment(token) for token in tokens] - - -def _process_bytebpe_tokens(tokens: Sequence[str]): - return [_process_bytebpe_token_for_alignment(token) for token in tokens] - - -def _process_wordpiece_token_for_alignment(t): - """Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT).""" - if t.startswith("##"): - return re.sub(r"^##", "", t) - else: - return "" + t - - -def _process_sentencepiece_token_for_alignment(t): - """Add word boundary markers, removes token prefix (space meta-symbol).""" - if t.startswith("▁"): - return "" + re.sub(r"^▁", "", t) - else: - return t - - -def _process_bytebpe_token_for_alignment(t): - """Add word boundary markers, removes token prefix (space meta-symbol).""" - if t.startswith("Ġ"): - return "" + re.sub(r"^Ġ", "", t) - else: - return t diff --git a/jiant/utils/tokenization_utils.py b/jiant/utils/tokenization_utils.py new file mode 100644 index 000000000..94041e5ec --- /dev/null +++ b/jiant/utils/tokenization_utils.py @@ -0,0 +1,49 @@ +import re + +from typing import Sequence + + +def bow_tag_tokens(tokens: Sequence[str], bow_tag: str = ""): + """Applies a beginning of word (BoW) marker to every token in the tokens sequence.""" + return [bow_tag + t for t in tokens] + + +def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = ""): + """Applies a end of word (EoW) marker to every token in the tokens sequence.""" + return [t + eow_tag for t in tokens] + + +def _process_wordpiece_tokens(tokens: Sequence[str]): + return [_process_wordpiece_token_for_alignment(token) for token in tokens] + + +def 
_process_sentencepiece_tokens(tokens: Sequence[str]): + return [_process_sentencepiece_token_for_alignment(token) for token in tokens] + + +def _process_bytebpe_tokens(tokens: Sequence[str]): + return [_process_bytebpe_token_for_alignment(token) for token in tokens] + + +def _process_wordpiece_token_for_alignment(t): + """Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT).""" + if t.startswith("##"): + return re.sub(r"^##", "", t) + else: + return "" + t + + +def _process_sentencepiece_token_for_alignment(t): + """Add word boundary markers, removes token prefix (space meta-symbol).""" + if t.startswith("▁"): + return "" + re.sub(r"^▁", "", t) + else: + return t + + +def _process_bytebpe_token_for_alignment(t): + """Add word boundary markers, removes token prefix (space meta-symbol).""" + if t.startswith("Ġ"): + return "" + re.sub(r"^Ġ", "", t) + else: + return t diff --git a/tests/utils/test_tokenization_normalization.py b/tests/utils/test_tokenization_normalization.py index c07ae44bf..08f0f0b6c 100644 --- a/tests/utils/test_tokenization_normalization.py +++ b/tests/utils/test_tokenization_normalization.py @@ -1,6 +1,7 @@ import pytest import jiant.utils.tokenization_normalization as tn +import jiant.utils.tokenization_utils as tu from transformers import BertTokenizer, XLMTokenizer, RobertaTokenizer, AlbertTokenizer @@ -52,7 +53,7 @@ def test_process_wordpiece_token_sequence(): "rules", ".", ] - adjusted_wordpiece_tokens = tn._process_wordpiece_tokens(original_wordpiece_tokens) + adjusted_wordpiece_tokens = tu._process_wordpiece_tokens(original_wordpiece_tokens) assert adjusted_wordpiece_tokens == expected_adjusted_wordpiece_tokens @@ -103,7 +104,7 @@ def test_process_sentencepiece_token_sequence(): "▁rules", ".", ] - adjusted_sentencepiece_tokens = tn._process_sentencepiece_tokens(original_sentencepiece_tokens) + adjusted_sentencepiece_tokens = tu._process_sentencepiece_tokens(original_sentencepiece_tokens) assert adjusted_sentencepiece_tokens == expected_adjusted_sentencepiece_tokens @@ -144,7 +145,7 @@ def test_process_bytebpe_token_sequence(): "Ġrules", ".", ] - adjusted_bytebpe_tokens = tn._process_bytebpe_tokens(original_bytebpe_tokens) + adjusted_bytebpe_tokens = tu._process_bytebpe_tokens(original_bytebpe_tokens) assert adjusted_bytebpe_tokens == expected_adjusted_bytebpe_tokens From 391bd59f5943cfba2d0c3e58a1332eb100b43ba2 Mon Sep 17 00:00:00 2001 From: jeswan <57466294+jeswan@users.noreply.github.com> Date: Thu, 11 Mar 2021 10:09:45 -0500 Subject: [PATCH 2/4] Update jiant/shared/model_resolution.py --- jiant/shared/model_resolution.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jiant/shared/model_resolution.py b/jiant/shared/model_resolution.py index b4eab62fe..2f7b19175 100644 --- a/jiant/shared/model_resolution.py +++ b/jiant/shared/model_resolution.py @@ -64,7 +64,6 @@ def resolve_model_arch_tokenizer(tokenizer): ModelArchitecture """ - print(TOKENIZER_CLASS_DICT.inverse) assert len(TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__]) == 1 return TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__][0] From 33faa2569ad7b95b9f7371cee870195851909392 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Thu, 18 Mar 2021 21:26:49 -0400 Subject: [PATCH 3/4] code review feedback: tokenization utils public and abstract tokenization normalization --- jiant/proj/main/modeling/primary.py | 19 ++++++++++++------- jiant/utils/tokenization_utils.py | 18 +++++++++--------- .../utils/test_tokenization_normalization.py | 6 +++--- 3 files changed, 24 
insertions(+), 19 deletions(-) diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py index d6f1bd9c0..cb5c9a214 100644 --- a/jiant/proj/main/modeling/primary.py +++ b/jiant/proj/main/modeling/primary.py @@ -21,9 +21,9 @@ from jiant.utils.tokenization_utils import bow_tag_tokens from jiant.utils.tokenization_utils import eow_tag_tokens -from jiant.utils.tokenization_utils import _process_bytebpe_tokens -from jiant.utils.tokenization_utils import _process_wordpiece_tokens -from jiant.utils.tokenization_utils import _process_sentencepiece_tokens +from jiant.utils.tokenization_utils import process_bytebpe_tokens +from jiant.utils.tokenization_utils import process_wordpiece_tokens +from jiant.utils.tokenization_utils import process_sentencepiece_tokens @dataclass @@ -175,6 +175,11 @@ def __init__(self, baseObject): ) self.__dict__ = baseObject.__dict__ + @classmethod + @abc.abstractmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + pass + @abc.abstractmethod def get_mlm_weights_dict(self, weights_dict): pass @@ -214,7 +219,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat if tokenizer.init_kwargs.get("do_lower_case", False): space_tokenization = [token.lower() for token in space_tokenization] modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = _process_wordpiece_tokens(target_tokenization) + modifed_target_tokenization = process_wordpiece_tokens(target_tokenization) return modifed_space_tokenization, modifed_target_tokenization @@ -256,7 +261,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat """See tokenization_normalization.py for details""" modifed_space_tokenization = bow_tag_tokens(space_tokenization) modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:] - modifed_target_tokenization = _process_bytebpe_tokens(modifed_target_tokenization) + modifed_target_tokenization = process_bytebpe_tokens(modifed_target_tokenization) return modifed_space_tokenization, modifed_target_tokenization @@ -296,7 +301,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat """See tokenization_normalization.py for details""" space_tokenization = [token.lower() for token in space_tokenization] modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization) + modifed_target_tokenization = process_sentencepiece_tokens(target_tokenization) return modifed_space_tokenization, modifed_target_tokenization @@ -366,7 +371,7 @@ def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenizat """See tokenization_normalization.py for details""" space_tokenization = [token.lower() for token in space_tokenization] modifed_space_tokenization = bow_tag_tokens(space_tokenization) - modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization) + modifed_target_tokenization = process_sentencepiece_tokens(target_tokenization) return modifed_space_tokenization, modifed_target_tokenization diff --git a/jiant/utils/tokenization_utils.py b/jiant/utils/tokenization_utils.py index 94041e5ec..7e250348d 100644 --- a/jiant/utils/tokenization_utils.py +++ b/jiant/utils/tokenization_utils.py @@ -13,19 +13,19 @@ def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = ""): return [t + eow_tag for t in tokens] -def _process_wordpiece_tokens(tokens: Sequence[str]): - return 
[_process_wordpiece_token_for_alignment(token) for token in tokens] +def process_wordpiece_tokens(tokens: Sequence[str]): + return [process_wordpiece_token_for_alignment(token) for token in tokens] -def _process_sentencepiece_tokens(tokens: Sequence[str]): - return [_process_sentencepiece_token_for_alignment(token) for token in tokens] +def process_sentencepiece_tokens(tokens: Sequence[str]): + return [process_sentencepiece_token_for_alignment(token) for token in tokens] -def _process_bytebpe_tokens(tokens: Sequence[str]): - return [_process_bytebpe_token_for_alignment(token) for token in tokens] +def process_bytebpe_tokens(tokens: Sequence[str]): + return [process_bytebpe_token_for_alignment(token) for token in tokens] -def _process_wordpiece_token_for_alignment(t): +def process_wordpiece_token_for_alignment(t): """Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT).""" if t.startswith("##"): return re.sub(r"^##", "", t) @@ -33,7 +33,7 @@ def _process_wordpiece_token_for_alignment(t): return "" + t -def _process_sentencepiece_token_for_alignment(t): +def process_sentencepiece_token_for_alignment(t): """Add word boundary markers, removes token prefix (space meta-symbol).""" if t.startswith("▁"): return "" + re.sub(r"^▁", "", t) @@ -41,7 +41,7 @@ def _process_sentencepiece_token_for_alignment(t): return t -def _process_bytebpe_token_for_alignment(t): +def process_bytebpe_token_for_alignment(t): """Add word boundary markers, removes token prefix (space meta-symbol).""" if t.startswith("Ġ"): return "" + re.sub(r"^Ġ", "", t) diff --git a/tests/utils/test_tokenization_normalization.py b/tests/utils/test_tokenization_normalization.py index 08f0f0b6c..71fe262c1 100644 --- a/tests/utils/test_tokenization_normalization.py +++ b/tests/utils/test_tokenization_normalization.py @@ -53,7 +53,7 @@ def test_process_wordpiece_token_sequence(): "rules", ".", ] - adjusted_wordpiece_tokens = tu._process_wordpiece_tokens(original_wordpiece_tokens) + adjusted_wordpiece_tokens = tu.process_wordpiece_tokens(original_wordpiece_tokens) assert adjusted_wordpiece_tokens == expected_adjusted_wordpiece_tokens @@ -104,7 +104,7 @@ def test_process_sentencepiece_token_sequence(): "▁rules", ".", ] - adjusted_sentencepiece_tokens = tu._process_sentencepiece_tokens(original_sentencepiece_tokens) + adjusted_sentencepiece_tokens = tu.process_sentencepiece_tokens(original_sentencepiece_tokens) assert adjusted_sentencepiece_tokens == expected_adjusted_sentencepiece_tokens @@ -145,7 +145,7 @@ def test_process_bytebpe_token_sequence(): "Ġrules", ".", ] - adjusted_bytebpe_tokens = tu._process_bytebpe_tokens(original_bytebpe_tokens) + adjusted_bytebpe_tokens = tu.process_bytebpe_tokens(original_bytebpe_tokens) assert adjusted_bytebpe_tokens == expected_adjusted_bytebpe_tokens From ad49c646447b6ff4447d1e3b7518d4a65d961952 Mon Sep 17 00:00:00 2001 From: Jesse Swanson Date: Fri, 19 Mar 2021 10:32:58 -0400 Subject: [PATCH 4/4] implement abstract methods for JiantTransformerModels --- jiant/proj/main/modeling/primary.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py index cb5c9a214..4b991ff97 100644 --- a/jiant/proj/main/modeling/primary.py +++ b/jiant/proj/main/modeling/primary.py @@ -423,6 +423,13 @@ def get_feat_spec(self, max_seq_length): sep_token_extra=False, ) + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + raise NotImplementedError() + + def 
get_mlm_weights_dict(self, weights_dict): + raise NotImplementedError() + @JiantTransformersModelFactory.register(ModelArchitectures.BART) class JiantBartModel(JiantTransformersModel): @@ -475,12 +482,19 @@ def __call__(self, encoder, input_ids, input_mask): pooled = unpooled[batch_idx, slen - input_ids.eq(encoder.config.pad_token_id).sum(1) - 1] return JiantModelOutput(pooled=pooled, unpooled=unpooled, other=other) + def get_mlm_weights_dict(self, weights_dict): + raise NotImplementedError() + @JiantTransformersModelFactory.register(ModelArchitectures.MBART) class JiantMBartModel(JiantBartModel): def __init__(self, baseObject): super().__init__(baseObject) + @classmethod + def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization): + raise NotImplementedError() + def get_feat_spec(self, max_seq_length): # mBART is weird # token 0 = '' which is the cls_token @@ -498,3 +512,6 @@ def get_feat_spec(self, max_seq_length): sequence_b_segment_id=0, # mBART has no token_type_ids sep_token_extra=True, ) + + def get_mlm_weights_dict(self, weights_dict): + raise NotImplementedError()
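
Illustrative sketch (not part of the patch series): the dispatch path that normalize_tokenizations follows after these changes, written out end to end. The imports, the factory/registry calls, and the classmethod signature are taken from the diffs above; the "roberta-base" checkpoint name and the sample sentence are assumptions for the example only.

    from transformers import RobertaTokenizer

    from jiant.proj.main.modeling.primary import JiantTransformersModelFactory
    from jiant.shared.model_resolution import resolve_model_arch_tokenizer

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")  # checkpoint name assumed
    text = "Tokenization normalization aligns two tokenizations."

    space_tokenization = text.split()               # plain whitespace tokenization
    target_tokenization = tokenizer.tokenize(text)  # byte-level BPE tokenization

    # The same lookup chain normalize_tokenizations() uses after this series:
    # tokenizer class -> ModelArchitectures member (via TOKENIZER_CLASS_DICT.inverse)
    # -> registered JiantTransformersModel subclass -> model-specific normalization.
    model_arch = resolve_model_arch_tokenizer(tokenizer)
    model_cls = JiantTransformersModelFactory.get_registry()[model_arch]
    norm_space, norm_target = model_cls.normalize_tokenizations(
        tokenizer, space_tokenization, target_tokenization
    )

    # Normalization must not change sequence lengths, or downstream alignment breaks.
    assert len(norm_space) == len(space_tokenization)
    assert len(norm_target) == len(target_tokenization)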
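
As a rough sketch of the contract this series leaves for future encoder wrappers: a hypothetical subclass (JiantNewModel and ModelArchitectures.NEW are placeholder names, not part of the patch) registers with the factory and either implements or explicitly declines the methods made abstract here. get_feat_spec and the other existing hooks are omitted for brevity.

    @JiantTransformersModelFactory.register(ModelArchitectures.NEW)  # placeholder enum member
    class JiantNewModel(JiantTransformersModel):
        def __init__(self, baseObject):
            super().__init__(baseObject)

        @classmethod
        def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
            # Tag whitespace tokens with a beginning-of-word marker and strip the
            # tokenizer's meta-symbols so the two sequences keep the same length.
            modified_space_tokenization = bow_tag_tokens(space_tokenization)
            modified_target_tokenization = process_sentencepiece_tokens(target_tokenization)
            return modified_space_tokenization, modified_target_tokenization

        def get_mlm_weights_dict(self, weights_dict):
            # Raise, as the wrappers without an exposed MLM head do in this series,
            # if the architecture provides no masked-language-modeling weights.
            raise NotImplementedError()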