BERTTensorizerBaseImpl to reimplement BERTTensorizerBase to be TorchScriptable (#1163)

Summary:
Pull Request resolved: #1163

BERTTensorizerBaseImpl to reimplement BERTTensorizerBase to be TorchScriptable
Overall design:

The PyText tensorizer (for example, RoBERTaTensorizer) delegates its numberize and tensorize logic to a scripted tensorizer implementation (for example, RoBERTaTensorizerImpl).

This requires reimplementing the numberize() and tensorize() logic to be TorchScriptable. The good news is that we already have such an implementation in pytext/torchscript/tensorizer; we just need to make minor changes.

On the PyText tensorizer side, numberize and tensorize delegate to tensorizer_impl:
```
def numberize(self, row: Dict) -> Tuple[Any, ...]:
    per_sentence_tokens = [
        self.tokenizer.tokenize(row[column]) for column in self.columns
    ]
    return self.tensorizer_impl.numberize(per_sentence_tokens)

def tensorize(self, batch) -> Tuple[torch.Tensor, ...]:
    tokens, segment_labels, seq_lens, positions = zip(*batch)
    return self.tensorizer_impl.tensorize(
        tokens, segment_labels, seq_lens, positions
    )
```
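
For context, a hedged sketch of how the scripted implementation could then be compiled on its own once built; the tokenizer and vocab objects below are placeholders, and RoBERTaTensorizerScriptImpl is the alias this diff introduces in roberta_tensorizer.py:

```python
import torch

# Hedged sketch, not part of this diff: the scripted implementation is a
# plain torch.nn.Module, so once constructed it can be compiled directly.
impl = RoBERTaTensorizerScriptImpl(
    tokenizer=tokenizer,  # placeholder: a PyText Tokenizer instance
    vocab=vocab,          # placeholder: a PyText Vocabulary instance
    max_seq_len=256,
)
impl.set_device("cpu")  # exported method defined on TensorizerScriptImpl
scripted = torch.jit.script(impl)
```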

Reviewed By: rutyrinott

Differential Revision: D18651538

fbshipit-source-id: aa56e84716496b73c021b70f996734215fb8f9ab
chenyangyu1988 authored and facebook-github-bot committed Dec 13, 2019
1 parent e780023 commit 39467dc
Showing 5 changed files with 353 additions and 7 deletions.
205 changes: 204 additions & 1 deletion pytext/data/bert_tensorizer.py
@@ -8,7 +8,7 @@
from fairseq.data.dictionary import Dictionary
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from pytext.config.component import ComponentType, create_component
from pytext.data.tensorizers import Tensorizer, lookup_tokens
from pytext.data.tensorizers import Tensorizer, TensorizerScriptImpl, lookup_tokens
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
from pytext.data.utils import (
    BOS,
@@ -21,6 +21,8 @@
    pad_and_tensorize,
)
from pytext.torchscript.tensorizer import ScriptBERTTensorizer
from pytext.torchscript.tensorizer.tensorizer import VocabLookup
from pytext.torchscript.utils import pad_2d, pad_2d_mask
from pytext.torchscript.vocab import ScriptVocabulary
from pytext.utils.file_io import PathManager

@@ -47,6 +49,207 @@ def build_fairseq_vocab(
    )


class BERTTensorizerBaseScriptImpl(TensorizerScriptImpl):
    def __init__(self, tokenizer: Tokenizer, vocab: Vocabulary, max_seq_len: int):
        super().__init__()
        self.tokenizer = tokenizer.torchscriptify()
        self.vocab = ScriptVocabulary(
            list(vocab),
            pad_idx=vocab.get_pad_index(),
            bos_idx=vocab.get_bos_index(-1),
            eos_idx=vocab.get_eos_index(-1),
            unk_idx=vocab.get_unk_index(),
        )
        self.vocab_lookup = VocabLookup(self.vocab)
        self.max_seq_len = max_seq_len

    def _lookup_tokens(
        self, tokens: List[Tuple[str, int, int]], max_seq_len: Optional[int] = None
    ) -> Tuple[List[int], List[int], List[int]]:
        """
        This function knows how to call lookup_tokens with the correct
        settings for this model. The default behavior is to wrap the
        numberized text with distinct BOS and EOS tokens. The resulting
        vector would look something like this:

            [BOS, token1_id, ..., tokenN_id, EOS]

        The function also takes an optional max_seq_len parameter, which is
        used to customize truncation in case we have multiple text fields.
        By default self.max_seq_len is used. It's up to the numberize
        function of the class to decide how to use this parameter. For
        example:

        - In the case of sentence pair classification, we might want both
          pieces of text to have the same length, which is half of the
          max_seq_len supported by the model.
        - In the case of QA, we might want to truncate the context by a
          seq_len which is longer than what we use for the question.

        Args:
            tokens: a list of tokens representing a sentence, each token
                represented by its token string, start index and end index.

        Returns:
            token_ids: List[int], the token ids of the sentence.
            start_indices: List[int], the start index of each token in the
                sentence.
            end_indices: List[int], the end index of each token in the
                sentence.
        """
        if max_seq_len is None:
            max_seq_len = self.max_seq_len

        return self.vocab_lookup(
            tokens,
            bos_idx=self.vocab.bos_idx,
            eos_idx=self.vocab.eos_idx,
            use_eos_token_for_bos=False,
            max_seq_len=max_seq_len,
        )
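
    # Illustrative example (hypothetical ids, not from a real vocab): for
    # input [("hello", 0, 5), ("world", 6, 11)], _lookup_tokens returns
    # roughly ([bos_idx, 21, 37, eos_idx], [0, 6], [5, 11]): token ids
    # wrapped with BOS/EOS, plus each token's start and end character
    # offsets; the exact handling of offsets for the special tokens is up
    # to VocabLookup.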

    def _wrap_numberized_tokens(
        self, numberized_tokens: List[int], idx: int
    ) -> List[int]:
        """
        If a class has a non-standard way of generating the final numberized
        text (e.g. BERT), then a class-specific version of this function
        should be implemented. This allows us to share the numberize function
        across classes without having to copy-paste code. The default
        implementation doesn't do anything.
        """
        return numberized_tokens

    def numberize(
        self, per_sentence_tokens: List[List[Tuple[str, int, int]]]
    ) -> Tuple[List[int], List[int], int, List[int]]:
        """
        This function contains the logic for converting tokens into ids based
        on the specified vocab. It also outputs, for each instance, the
        vectors needed to run the actual model.

        Args:
            per_sentence_tokens: list of tokens per sentence in one row, each
                token represented by its token string, start index and end
                index.

        Returns:
            tokens: List[int], the token ids of all sentences, concatenated.
            segment_labels: List[int], the index of the sentence each token
                belongs to.
            seq_len: int, the total number of tokens.
            positions: List[int], the position of each token.
        """
        tokens: List[int] = []
        segment_labels: List[int] = []
        seq_len: int = 0
        positions: List[int] = []

        for idx, single_sentence_tokens in enumerate(per_sentence_tokens):
            lookup_ids: List[int] = self._lookup_tokens(single_sentence_tokens)[0]
            lookup_ids = self._wrap_numberized_tokens(lookup_ids, idx)

            tokens.extend(lookup_ids)
            segment_labels.extend([idx] * len(lookup_ids))

        seq_len = len(tokens)
        positions = [i for i in range(seq_len)]
        return tokens, segment_labels, seq_len, positions

    def tensorize(
        self,
        tokens_2d: List[List[int]],
        segment_labels_2d: List[List[int]],
        seq_lens_1d: List[int],
        positions_2d: List[List[int]],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Convert instance-level vectors into batch-level tensors.
        """
        tokens, pad_mask = pad_2d_mask(tokens_2d, pad_value=self.vocab.pad_idx)
        segment_labels = torch.tensor(
            pad_2d(segment_labels_2d, seq_lens=seq_lens_1d, pad_idx=0), dtype=torch.long
        )
        positions = torch.tensor(
            pad_2d(positions_2d, seq_lens=seq_lens_1d, pad_idx=0), dtype=torch.long
        )
        if self.device == "":
            return tokens, pad_mask, segment_labels, positions
        else:
            return (
                tokens.to(self.device),
                pad_mask.to(self.device),
                segment_labels.to(self.device),
                positions.to(self.device),
            )

    def tokenize(
        self,
        row_text: Optional[List[str]],
        row_pre_tokenized: Optional[List[List[str]]],
    ) -> List[List[Tuple[str, int, int]]]:
        """
        This function converts raw inputs into tokens, where each token is
        represented by its token string, start index and end index in the
        raw input. There are two possible inputs to this function, depending
        on whether the tokenizer is implemented in TorchScript or not.

        Case 1: the tokenizer has a full TorchScript implementation; the
        input will be a list of sentences (in most cases a single sentence
        or a pair).
        Case 2: the tokenizer has a partial or no TorchScript implementation;
        in most cases the tokenizer will be hosted in Yoda, and the input
        will be a list of pre-processed tokens.

        Returns:
            per_sentence_tokens: tokens per sentence, each token represented
                by its token string, start index and end index.
        """
        per_sentence_tokens: List[List[Tuple[str, int, int]]] = []

        if row_text is not None:
            for text in row_text:
                per_sentence_tokens.append(self.tokenizer.tokenize(text))
        elif row_pre_tokenized is not None:
            for sentence_pre_tokenized in row_pre_tokenized:
                sentence_tokens: List[Tuple[str, int, int]] = []
                for token in sentence_pre_tokenized:
                    sentence_tokens.extend(self.tokenizer.tokenize(token))
                per_sentence_tokens.append(sentence_tokens)

        return per_sentence_tokens

    def forward(
        self,
        texts: Optional[List[List[str]]] = None,
        pre_tokenized: Optional[List[List[List[str]]]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Wire up the tokenize(), numberize() and tensorize() functions for
        data processing.
        When exporting to TorchScript, the wrapper module should choose
        between texts and pre_tokenized based on the TorchScript tokenizer
        implementation (e.g. whether an external tokenizer such as Yoda is
        used).
        """
        tokens_2d: List[List[int]] = []
        segment_labels_2d: List[List[int]] = []
        seq_lens_1d: List[int] = []
        positions_2d: List[List[int]] = []

        for idx in range(self.batch_size(texts, pre_tokenized)):
            tokens: List[List[Tuple[str, int, int]]] = self.tokenize(
                self.get_texts_by_index(texts, idx),
                self.get_tokens_by_index(pre_tokenized, idx),
            )

            numberized: Tuple[List[int], List[int], int, List[int]] = self.numberize(
                tokens
            )
            tokens_2d.append(numberized[0])
            segment_labels_2d.append(numberized[1])
            seq_lens_1d.append(numberized[2])
            positions_2d.append(numberized[3])

        return self.tensorize(tokens_2d, segment_labels_2d, seq_lens_1d, positions_2d)
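
To make the numberize() and tensorize() contracts above concrete, here is a hedged, illustrative walk-through for a sentence pair; the token ids are hypothetical:

```python
# Hypothetical per-sentence lookups (assume BOS=101, EOS=102):
#   sentence 0 -> [101, 7, 8, 102]
#   sentence 1 -> [101, 9, 102]
# numberize() concatenates them and records the sentence of each token:
tokens = [101, 7, 8, 102, 101, 9, 102]
segment_labels = [0, 0, 0, 0, 1, 1, 1]
seq_len = 7
positions = [0, 1, 2, 3, 4, 5, 6]

# tensorize() then pads a batch of such instances to the longest sequence
# in the batch and returns (tokens, pad_mask, segment_labels, positions)
# tensors, moving them to self.device when one was set via set_device().
```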


class BERTTensorizerBase(Tensorizer):
    """
    Base Tensorizer class for all BERT style models including XLM,
12 changes: 11 additions & 1 deletion pytext/data/roberta_tensorizer.py
@@ -2,15 +2,25 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from pytext.config.component import ComponentType, create_component
from pytext.data.bert_tensorizer import BERTTensorizerBase, build_fairseq_vocab
from pytext.data.bert_tensorizer import (
    BERTTensorizerBase,
    BERTTensorizerBaseScriptImpl,
    build_fairseq_vocab,
)
from pytext.data.tokenizers import GPT2BPETokenizer, Tokenizer
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK
from pytext.torchscript.tensorizer import ScriptRoBERTaTensorizer
from pytext.torchscript.vocab import ScriptVocabulary
from pytext.utils.file_io import PathManager


RoBERTaTensorizerScriptImpl = BERTTensorizerBaseScriptImpl


class RoBERTaTensorizer(BERTTensorizerBase):

    __TENSORIZER_SCRIPT_IMPL__ = RoBERTaTensorizerScriptImpl

    class Config(BERTTensorizerBase.Config):
        vocab_file: str = (
            "manifold://pytext_training/tree/static/vocabs/bpe/gpt2/dict.txt"
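
End to end, the aliased implementation's forward() takes a batch of rows of raw text (or pre-tokenized input) and returns batch-level tensors; a hedged continuation of the sketch from the summary above:

```python
# Each row is a list of sentences: here one pair and one single sentence.
batch_texts = [
    ["hello world", "how are you"],
    ["single sentence"],
]
tokens, pad_mask, segment_labels, positions = scripted(texts=batch_texts)
# Each output has shape (batch_size, longest_sequence_in_batch).
```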
77 changes: 77 additions & 0 deletions pytext/data/tensorizers.py
@@ -93,6 +93,82 @@ def lookup_tokens(
    return tokens, start_idx, end_idx


class TensorizerScriptImpl(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.device: str = ""

    @torch.jit.export
    def set_device(self, device: str):
        self.device = device

    def batch_size(
        self, texts: Optional[List[List[str]]], tokens: Optional[List[List[List[str]]]]
    ) -> int:
        if texts is not None:
            return len(texts)
        elif tokens is not None:
            return len(tokens)
        else:
            raise RuntimeError("Empty input for both texts and tokens.")

    def row_size(
        self, texts: Optional[List[List[str]]], tokens: Optional[List[List[List[str]]]]
    ) -> int:
        if texts is not None:
            return len(texts[0])
        elif tokens is not None:
            return len(tokens[0])
        else:
            raise RuntimeError("Empty input for both texts and tokens.")

    def get_texts_by_index(
        self, texts: Optional[List[List[str]]], index: int
    ) -> Optional[List[str]]:
        if texts is None:
            return None
        return texts[index]

    def get_tokens_by_index(
        self, tokens: Optional[List[List[List[str]]]], index: int
    ) -> Optional[List[List[str]]]:
        if tokens is None:
            return None
        return tokens[index]

    def tokenize(self, *args, **kwargs):
        """
        This function receives the inputs from clients. There are usually
        two possible inputs:
        1) a row of texts: List[str]
        2) a row of pre-processed tokens: List[List[str]]
        Override this function to be TorchScriptable, e.g. you need to
        declare concrete input arguments with type hints.
        """
        raise NotImplementedError

    def numberize(self, *args, **kwargs):
        """
        This function receives the outputs of tokenize(), or is called
        directly from the PyText tensorizer's numberize() function.
        Override this function to be TorchScriptable, e.g. you need to
        declare concrete input arguments with type hints.
        """
        raise NotImplementedError

    def tensorize(self, *args, **kwargs):
        """
        This function receives a list (e.g. a batch) of outputs from
        numberize(), pads them and converts them to output tensors.
        Override this function to be TorchScriptable, e.g. you need to
        declare concrete input arguments with type hints.
        """
        raise NotImplementedError


class Tensorizer(Component):
    """Tensorizers are a component that converts from batches of
    `pytext.data.type.DataType` instances to tensors. These tensors will eventually
@@ -108,6 +184,7 @@ class Tensorizer(Component):

    __COMPONENT_TYPE__ = ComponentType.TENSORIZER
    __EXPANSIBLE__ = True
    __TENSORIZER_SCRIPT_IMPL__ = None

    class Config(Component.Config):
        pass
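
As the stub docstrings above note, a subclass becomes TorchScriptable by overriding tokenize(), numberize() and tensorize() with concrete type hints. A minimal hedged sketch of such a subclass; the whitespace tokenization and toy vocab are illustrative, not PyText's real logic:

```python
from typing import Dict, List, Optional

import torch

from pytext.data.tensorizers import TensorizerScriptImpl


class ToyTensorizerScriptImpl(TensorizerScriptImpl):
    """Hypothetical subclass: whitespace tokens and a tiny fixed vocab."""

    def __init__(self, vocab: Dict[str, int], pad_idx: int):
        super().__init__()
        self.vocab: Dict[str, int] = vocab
        self.pad_idx = pad_idx

    def tokenize(self, row_text: Optional[List[str]]) -> List[str]:
        # Concrete type hints (no *args/**kwargs) keep this scriptable.
        tokens: List[str] = []
        if row_text is not None:
            for text in row_text:
                tokens.extend(text.split(" "))
        return tokens

    def numberize(self, tokens: List[str]) -> List[int]:
        ids: List[int] = []
        for token in tokens:
            if token in self.vocab:
                ids.append(self.vocab[token])
            else:
                ids.append(self.pad_idx)  # toy fallback for OOV tokens
        return ids

    def tensorize(self, batch: List[List[int]]) -> torch.Tensor:
        # Pad every instance to the longest sequence in the batch.
        max_len = 0
        for ids in batch:
            max_len = max(max_len, len(ids))
        padded: List[List[int]] = []
        for ids in batch:
            padded.append(ids + [self.pad_idx] * (max_len - len(ids)))
        return torch.tensor(padded, dtype=torch.long)

    def forward(self, texts: Optional[List[str]] = None) -> torch.Tensor:
        # One row in, one (1 x seq_len) tensor out.
        return self.tensorize([self.numberize(self.tokenize(texts))])


impl = ToyTensorizerScriptImpl({"hello": 1, "world": 2}, pad_idx=0)
scripted = torch.jit.script(impl)  # compiles: all hints are concrete
print(scripted(["hello world", "hello"]))  # tensor([[1, 2, 1]])
```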