Move tokenization logic to central JiantModelTransformers method #1290

Merged
merged 4 commits (Mar 19, 2021)
Changes from 2 commits
78 changes: 66 additions & 12 deletions jiant/proj/main/modeling/primary.py
@@ -1,6 +1,7 @@
import abc

from dataclasses import dataclass

from typing import Any
from typing import Callable
from typing import Dict
@@ -9,14 +10,20 @@
import torch
import torch.nn as nn

import jiant.tasks as tasks
import jiant.utils.python.strings as strings
from jiant.tasks.core import BatchMixin
from jiant.tasks.core import FeaturizationSpec
from jiant.tasks.core import Task

from jiant.proj.main.components.outputs import construct_output_from_dict
from jiant.proj.main.modeling.taskmodels import Taskmodel
from jiant.shared.model_resolution import ModelArchitectures
from jiant.tasks.core import FeaturizationSpec

from jiant.proj.main.components.outputs import construct_output_from_dict
from jiant.utils.tokenization_utils import bow_tag_tokens
from jiant.utils.tokenization_utils import eow_tag_tokens
from jiant.utils.tokenization_utils import _process_bytebpe_tokens
from jiant.utils.tokenization_utils import _process_wordpiece_tokens
from jiant.utils.tokenization_utils import _process_sentencepiece_tokens


@dataclass
@@ -29,7 +36,7 @@ class JiantModelOutput:
class JiantModel(nn.Module):
def __init__(
self,
task_dict: Dict[str, tasks.Task],
task_dict: Dict[str, Task],
encoder: nn.Module,
taskmodels_dict: Dict[str, Taskmodel],
task_to_taskmodel_map: Dict[str, str],
@@ -42,15 +49,15 @@ def __init__(
self.task_to_taskmodel_map = task_to_taskmodel_map
self.tokenizer = tokenizer

def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool = False):
def forward(self, batch: BatchMixin, task: Task, compute_loss: bool = False):
"""Calls to this forward method are delegated to the forward of the appropriate taskmodel.

When JiantModel forward is called, the task name from the task argument is used as a key
to select the appropriate submodule/taskmodel, and that taskmodel's forward is called.

Args:
batch (tasks.BatchMixin): model input.
task (tasks.Task): task to which to delegate the forward call.
batch (BatchMixin): model input.
task (Task): task to which to delegate the forward call.
compute_loss (bool): whether to calculate and return the loss.

Returns:
@@ -74,20 +81,20 @@ def forward(self, batch: tasks.BatchMixin, task: tasks.Task, compute_loss: bool

def wrap_jiant_forward(
jiant_model: Union[JiantModel, nn.DataParallel],
batch: tasks.BatchMixin,
task: tasks.Task,
batch: BatchMixin,
task: Task,
compute_loss: bool = False,
):
"""Wrapper to repackage model inputs using dictionaries for compatibility with DataParallel.

Wrapper that converts batches (type tasks.BatchMixin) to dictionaries before delegating to
Wrapper that converts batches (type BatchMixin) to dictionaries before delegating to
JiantModel's forward method, and then converts the resulting model output dict into the
appropriate model output dataclass.

Args:
jiant_model (Union[JiantModel, nn.DataParallel]):
batch (tasks.BatchMixin): model input batch.
task (tasks.Task): Task object passed for access in the taskmodel.
batch (BatchMixin): model input batch.
task (Task): Task object passed for access in the taskmodel.
compute_loss (bool): True if loss should be computed, False otherwise.

Returns:
@@ -201,6 +208,16 @@ class JiantBertModel(JiantTransformersModel):
def __init__(self, baseObject):
super().__init__(baseObject)

@classmethod
def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
"""See tokenization_normalization.py for details"""
if tokenizer.init_kwargs.get("do_lower_case", False):
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = _process_wordpiece_tokens(target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

def get_feat_spec(self, max_seq_length):
return FeaturizationSpec(
max_seq_length=max_seq_length,
@@ -234,6 +251,15 @@ class JiantRobertaModel(JiantTransformersModel):
def __init__(self, baseObject):
super().__init__(baseObject)

@classmethod
def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
"""See tokenization_normalization.py for details"""
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:]
modifed_target_tokenization = _process_bytebpe_tokens(modifed_target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

def get_mlm_weights_dict(self, weights_dict):
mlm_weights_dict = {
strings.remove_prefix(k, "lm_head."): v for k, v in weights_dict.items()
@@ -265,6 +291,15 @@ class JiantXLMRobertaModel(JiantTransformersModel):
def __init__(self, baseObject):
super().__init__(baseObject)

@classmethod
def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
"""See tokenization_normalization.py for details"""
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

def get_feat_spec(self, max_seq_length):
# XLM-RoBERTa is weird
# token 0 = '<s>' which is the cls_token
@@ -296,6 +331,16 @@ class JiantXLMModel(JiantTransformersModel):
def __init__(self, baseObject):
super().__init__(baseObject)

@classmethod
def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
"""See tokenization_normalization.py for details"""
if tokenizer.init_kwargs.get("do_lowercase_and_remove_accent", False):
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = eow_tag_tokens(space_tokenization)
modifed_target_tokenization = target_tokenization

return modifed_space_tokenization, modifed_target_tokenization

def get_feat_spec(self, max_seq_length):
return FeaturizationSpec(
max_seq_length=max_seq_length,
@@ -316,6 +361,15 @@ class JiantAlbertModel(JiantTransformersModel):
def __init__(self, baseObject):
super().__init__(baseObject)

@classmethod
def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
"""See tokenization_normalization.py for details"""
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization

def get_mlm_weights_dict(self, weights_dict):
mlm_weights_dict = {
strings.remove_prefix(k, "predictions."): v for k, v in weights_dict.items()
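For orientation, a minimal usage sketch of the per-architecture classmethods added above. This is a hedged illustration: the checkpoint name and sample sentence are not part of this PR, and the exact subword split depends on the vocabulary.

import transformers

from jiant.proj.main.modeling.primary import JiantBertModel

# Illustrative checkpoint; init_kwargs["do_lower_case"] drives the lowercasing branch above.
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
sentence = "Jiant normalizes tokenizations"
space_tokenization = sentence.split()
target_tokenization = tokenizer.tokenize(sentence)

space_toks, target_toks = JiantBertModel.normalize_tokenizations(
    tokenizer, space_tokenization, target_tokenization
)
# space_toks is lowercased and carries "<w>" beginning-of-word tags; target_toks has
# WordPiece "##" prefixes stripped and "<w>" tags added, so the two sequences can be
# aligned downstream.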
37 changes: 27 additions & 10 deletions jiant/shared/model_resolution.py
@@ -1,5 +1,6 @@
from dataclasses import dataclass
from enum import Enum
from jiant.utils.python.datastructures import BiDict

import transformers

@@ -26,16 +27,18 @@ class ModelClassSpec:
model_class: type


TOKENIZER_CLASS_DICT = {
ModelArchitectures.BERT: transformers.BertTokenizer,
ModelArchitectures.XLM: transformers.XLMTokenizer,
ModelArchitectures.ROBERTA: transformers.RobertaTokenizer,
ModelArchitectures.XLM_ROBERTA: transformers.XLMRobertaTokenizer,
ModelArchitectures.ALBERT: transformers.AlbertTokenizer,
ModelArchitectures.BART: transformers.BartTokenizer,
ModelArchitectures.MBART: transformers.MBartTokenizer,
ModelArchitectures.ELECTRA: transformers.ElectraTokenizer,
}
TOKENIZER_CLASS_DICT = BiDict(
{
ModelArchitectures.BERT: transformers.BertTokenizer,
ModelArchitectures.XLM: transformers.XLMTokenizer,
ModelArchitectures.ROBERTA: transformers.RobertaTokenizer,
ModelArchitectures.XLM_ROBERTA: transformers.XLMRobertaTokenizer,
ModelArchitectures.ALBERT: transformers.AlbertTokenizer,
ModelArchitectures.BART: transformers.BartTokenizer,
ModelArchitectures.MBART: transformers.MBartTokenizer,
ModelArchitectures.ELECTRA: transformers.ElectraTokenizer,
}
)


def resolve_tokenizer_class(model_type):
@@ -51,6 +54,20 @@ def resolve_tokenizer_class(model_type):
return TOKENIZER_CLASS_DICT[ModelArchitectures(model_type)]


def resolve_model_arch_tokenizer(tokenizer):
"""Get the model architecture for a given tokenizer.

Args:
tokenizer: tokenizer instance whose class is used for the reverse lookup in TOKENIZER_CLASS_DICT.

Returns:
ModelArchitectures: the model architecture registered for the tokenizer's class.

"""
assert len(TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__]) == 1
return TOKENIZER_CLASS_DICT.inverse[tokenizer.__class__][0]


def resolve_is_lower_case(tokenizer):
if isinstance(tokenizer, transformers.BertTokenizer):
return tokenizer.basic_tokenizer.do_lower_case
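A short, hedged sketch of the reverse lookup this BiDict-backed mapping enables (the RoBERTa checkpoint name is illustrative, not part of this PR):

import transformers

from jiant.shared.model_resolution import ModelArchitectures, resolve_model_arch_tokenizer

tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
# TOKENIZER_CLASS_DICT.inverse maps the tokenizer class back to its architecture.
assert resolve_model_arch_tokenizer(tokenizer) == ModelArchitectures.ROBERTA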
41 changes: 41 additions & 0 deletions jiant/utils/python/datastructures.py
@@ -279,3 +279,44 @@ def get_maps(self) -> Tuple[Dict, Dict]:

"""
return self.a_to_b, self.b_to_a


class BiDict(dict):
"""Maintains bidirectional dict

Example:
bd = BiDict({'a': 1, 'b': 2})
print(bd) # {'a': 1, 'b': 2}
print(bd.inverse) # {1: ['a'], 2: ['b']}
bd['c'] = 1 # Now two keys have the same value (= 1)
print(bd) # {'a': 1, 'c': 1, 'b': 2}
print(bd.inverse) # {1: ['a', 'c'], 2: ['b']}
del bd['c']
print(bd) # {'a': 1, 'b': 2}
print(bd.inverse) # {1: ['a'], 2: ['b']}
del bd['a']
print(bd) # {'b': 2}
print(bd.inverse) # {2: ['b']}
bd['b'] = 3
print(bd) # {'b': 3}
print(bd.inverse) # {2: [], 3: ['b']}

"""

def __init__(self, *args, **kwargs):
super(BiDict, self).__init__(*args, **kwargs)
self.inverse = {}
for key, value in self.items():
self.inverse.setdefault(value, []).append(key)

def __setitem__(self, key, value):
if key in self:
self.inverse[self[key]].remove(key)
super(BiDict, self).__setitem__(key, value)
self.inverse.setdefault(value, []).append(key)

def __delitem__(self, key):
self.inverse.setdefault(self[key], []).remove(key)
if self[key] in self.inverse and not self.inverse[self[key]]:
del self.inverse[self[key]]
super(BiDict, self).__delitem__(key)
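To illustrate why resolve_model_arch_tokenizer above asserts exactly one inverse entry: BiDict.inverse maps each value to a list of keys, so the lookup would become ambiguous if two architectures ever registered the same tokenizer class (the keys and value below are made-up placeholders):

from jiant.utils.python.datastructures import BiDict

bd = BiDict({"arch_a": "SharedTokenizer", "arch_b": "SharedTokenizer"})
print(bd.inverse["SharedTokenizer"])  # ['arch_a', 'arch_b'] -> ambiguous reverse lookup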
89 changes: 17 additions & 72 deletions jiant/utils/tokenization_normalization.py
@@ -8,11 +8,12 @@

"""

import re
import transformers
from typing import Sequence

from jiant.utils.testing import utils as test_utils
from jiant.shared.model_resolution import resolve_model_arch_tokenizer
from jiant.proj.main.modeling.primary import JiantTransformersModelFactory


def normalize_tokenizations(
@@ -49,80 +50,24 @@ def normalize_tokenizations(
if len(space_tokenization) == 0 or len(target_tokenization) == 0:
raise ValueError("Empty token sequence.")

if isinstance(tokenizer, transformers.BertTokenizer):
if tokenizer.init_kwargs.get("do_lower_case", False):
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = _process_wordpiece_tokens(target_tokenization)
elif isinstance(tokenizer, transformers.XLMTokenizer):
if tokenizer.init_kwargs.get("do_lowercase_and_remove_accent", False):
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = eow_tag_tokens(space_tokenization)
modifed_target_tokenization = target_tokenization
elif isinstance(tokenizer, transformers.RobertaTokenizer):
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = ["Ġ" + target_tokenization[0]] + target_tokenization[1:]
modifed_target_tokenization = _process_bytebpe_tokens(modifed_target_tokenization)
elif isinstance(tokenizer, (transformers.AlbertTokenizer, transformers.XLMRobertaTokenizer)):
space_tokenization = [token.lower() for token in space_tokenization]
modifed_space_tokenization = bow_tag_tokens(space_tokenization)
modifed_target_tokenization = _process_sentencepiece_tokens(target_tokenization)
else:
if test_utils.is_pytest():
from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer

if isinstance(tokenizer, SimpleSpaceTokenizer):
return space_tokenization, target_tokenization
raise ValueError("Tokenizer not supported.")
if test_utils.is_pytest():
from jiant.utils.testing.tokenizer import SimpleSpaceTokenizer

if isinstance(tokenizer, SimpleSpaceTokenizer):
return space_tokenization, target_tokenization

model_arch = resolve_model_arch_tokenizer(tokenizer)
print(model_arch)
jiant_transformer_model_class = JiantTransformersModelFactory.get_registry()[model_arch]
(
modifed_space_tokenization,
modifed_target_tokenization,
) = jiant_transformer_model_class.normalize_tokenizations(
tokenizer, space_tokenization, target_tokenization
)

# safety check: if normalization changed sequence length, alignment is likely to break.
assert len(modifed_space_tokenization) == len(space_tokenization)
assert len(modifed_target_tokenization) == len(target_tokenization)

return modifed_space_tokenization, modifed_target_tokenization


def bow_tag_tokens(tokens: Sequence[str], bow_tag: str = "<w>"):
"""Applies a beginning of word (BoW) marker to every token in the tokens sequence."""
return [bow_tag + t for t in tokens]


def eow_tag_tokens(tokens: Sequence[str], eow_tag: str = "</w>"):
"""Applies a end of word (EoW) marker to every token in the tokens sequence."""
return [t + eow_tag for t in tokens]


def _process_wordpiece_tokens(tokens: Sequence[str]):
return [_process_wordpiece_token_for_alignment(token) for token in tokens]


def _process_sentencepiece_tokens(tokens: Sequence[str]):
return [_process_sentencepiece_token_for_alignment(token) for token in tokens]


def _process_bytebpe_tokens(tokens: Sequence[str]):
return [_process_bytebpe_token_for_alignment(token) for token in tokens]


def _process_wordpiece_token_for_alignment(t):
"""Add word boundary markers, removes token prefix (no-space meta-symbol — '##' for BERT)."""
if t.startswith("##"):
return re.sub(r"^##", "", t)
else:
return "<w>" + t


def _process_sentencepiece_token_for_alignment(t):
"""Add word boundary markers, removes token prefix (space meta-symbol)."""
if t.startswith("▁"):
return "<w>" + re.sub(r"^▁", "", t)
else:
return t


def _process_bytebpe_token_for_alignment(t):
"""Add word boundary markers, removes token prefix (space meta-symbol)."""
if t.startswith("Ġ"):
return "<w>" + re.sub(r"^Ġ", "", t)
else:
return t
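Putting the pieces together, a hedged end-to-end sketch of the new dispatch path through normalize_tokenizations (checkpoint name and sentence are illustrative; keyword arguments are used in case the positional parameter order differs):

import transformers

from jiant.utils.tokenization_normalization import normalize_tokenizations

tokenizer = transformers.RobertaTokenizer.from_pretrained("roberta-base")
sentence = "Tokenization alignment example"

# Internally: resolve_model_arch_tokenizer(tokenizer) -> ModelArchitectures.ROBERTA, then
# JiantTransformersModelFactory.get_registry()[arch].normalize_tokenizations(...) does the work.
space_toks, target_toks = normalize_tokenizations(
    space_tokenization=sentence.split(),
    target_tokenization=tokenizer.tokenize(sentence),
    tokenizer=tokenizer,
)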