diff --git a/pytext/data/bert_tensorizer.py b/pytext/data/bert_tensorizer.py
index 4acf936e6..388091c4a 100644
--- a/pytext/data/bert_tensorizer.py
+++ b/pytext/data/bert_tensorizer.py
@@ -1,30 +1,20 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
-import itertools
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from fairseq.data.dictionary import Dictionary
 from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
 from pytext.config.component import ComponentType, create_component
-from pytext.data.tensorizers import Tensorizer, TensorizerScriptImpl, lookup_tokens
+from pytext.data.tensorizers import Tensorizer, TensorizerScriptImpl
 from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
-from pytext.data.utils import (
-    BOS,
-    EOS,
-    MASK,
-    PAD,
-    UNK,
-    SpecialToken,
-    Vocabulary,
-    pad_and_tensorize,
-)
-from pytext.torchscript.tensorizer import ScriptBERTTensorizer
+from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, SpecialToken, Vocabulary
 from pytext.torchscript.tensorizer.tensorizer import VocabLookup
 from pytext.torchscript.utils import pad_2d, pad_2d_mask
 from pytext.torchscript.vocab import ScriptVocabulary
 from pytext.utils.file_io import PathManager
+from pytext.utils.lazy import lazy_property
 
 
 def build_fairseq_vocab(
@@ -294,74 +284,28 @@ def __init__(
     def column_schema(self):
         return [(column, str) for column in self.columns]
 
-    def _lookup_tokens(self, text: str, seq_len: int = None):
-        """
-        This function knows how to call lookup_tokens with the correct
-        settings for this model. The default behavior is to wrap the
-        numberized text with distinct BOS and EOS tokens. The resulting
-        vector would look something like this:
-        [BOS, token1_id, . . . tokenN_id, EOS]
-
-        The function also takes an optional seq_len parameter which is
-        used to customize truncation in case we have multiple text fields.
-        By default max_seq_len is used. It's upto the numberize function of
-        the class to decide how to use the seq_len param.
-
-        For example:
-        - In the case of sentence pair classification, we might want both
-        pieces of text have the same length which is half of the
-        max_seq_len supported by the model.
-        - In the case of QA, we might want to truncate the context by a
-        seq_len which is longer than what we use for the question.
-        """
-        return lookup_tokens(
-            text,
-            tokenizer=self.tokenizer,
-            vocab=self.vocab,
-            bos_token=self.vocab.bos_token,
-            eos_token=self.vocab.eos_token,
-            max_seq_len=seq_len if seq_len else self.max_seq_len,
+    @lazy_property
+    def tensorizer_script_impl(self):
+        return self.__TENSORIZER_SCRIPT_IMPL__(
+            tokenizer=self.tokenizer, vocab=self.vocab, max_seq_len=self.max_seq_len
         )
 
-    def _wrap_numberized_text(
-        self, numberized_sentences: List[List[str]]
-    ) -> List[List[str]]:
-        """
-        If a class has a non-standard way of generating the final numberized text
-        (eg: BERT) then a class specific version of wrap_numberized_text function
-        should be implemented. This allows us to share the numberize
-        function across classes without having to copy paste code. The default
-        implementation doesnt do anything.
-        """
-        return numberized_sentences
-
     def numberize(self, row: Dict) -> Tuple[Any, ...]:
         """
         This function contains logic for converting tokens into ids based on
         the specified vocab. It also outputs, for each instance, the vectors
         needed to run the actual model.
""" - sentences = [self._lookup_tokens(row[column])[0] for column in self.columns] - sentences = self._wrap_numberized_text(sentences) - seq_lens = (len(sentence) for sentence in sentences) - segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens)) - tokens = list(itertools.chain(*sentences)) - segment_labels = list(itertools.chain(*segment_labels)) - seq_len = len(tokens) - positions = list(range(seq_len)) - # tokens, segment_label, seq_len - return tokens, segment_labels, seq_len, positions + per_sentence_tokens = [ + self.tokenizer.tokenize(row[column]) for column in self.columns + ] + return self.tensorizer_script_impl.numberize(per_sentence_tokens) def tensorize(self, batch) -> Tuple[torch.Tensor, ...]: """ Convert instance level vectors into batch level tensors. """ - tokens, segment_labels, seq_lens, positions = zip(*batch) - tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index()) - pad_mask = (tokens != self.vocab.get_pad_index()).long() - segment_labels = pad_and_tensorize(segment_labels) - positions = pad_and_tensorize(positions) - return tokens, pad_mask, segment_labels, positions + return self.tensorizer_script_impl.tensorize_wrapper(*zip(*batch)) def initialize(self, vocab_builder=None, from_scratch=True): # vocab for BERT is already set @@ -455,31 +399,3 @@ def __init__( super().__init__( columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len ) - - def _lookup_tokens(self, text: str, seq_len: int = None): - return lookup_tokens( - text, - tokenizer=self.tokenizer, - vocab=self.vocab, - bos_token=None, - eos_token=self.vocab.eos_token, - max_seq_len=seq_len if seq_len else self.max_seq_len, - ) - - def _wrap_numberized_text( - self, numberized_sentences: List[List[str]] - ) -> List[List[str]]: - numberized_sentences[0] = [self.vocab.get_bos_index()] + numberized_sentences[0] - return numberized_sentences - - def torchscriptify(self): - return ScriptBERTTensorizer( - tokenizer=self.tokenizer.torchscriptify(), - vocab=ScriptVocabulary( - list(self.vocab), - pad_idx=self.vocab.get_pad_index(), - bos_idx=self.vocab.get_bos_index(), - eos_idx=self.vocab.get_eos_index(), - ), - max_seq_len=self.max_seq_len, - ) diff --git a/pytext/data/roberta_tensorizer.py b/pytext/data/roberta_tensorizer.py index ce9c50ed7..e71207871 100644 --- a/pytext/data/roberta_tensorizer.py +++ b/pytext/data/roberta_tensorizer.py @@ -9,8 +9,6 @@ ) from pytext.data.tokenizers import GPT2BPETokenizer, Tokenizer from pytext.data.utils import BOS, EOS, MASK, PAD, UNK -from pytext.torchscript.tensorizer import ScriptRoBERTaTensorizer -from pytext.torchscript.vocab import ScriptVocabulary from pytext.utils.file_io import PathManager @@ -54,15 +52,3 @@ def from_config(cls, config: Config): max_seq_len=config.max_seq_len, base_tokenizer=base_tokenizer, ) - - def torchscriptify(self): - return ScriptRoBERTaTensorizer( - tokenizer=self.tokenizer.torchscriptify(), - vocab=ScriptVocabulary( - list(self.vocab), - pad_idx=self.vocab.get_pad_index(), - bos_idx=self.vocab.get_bos_index(), - eos_idx=self.vocab.get_eos_index(), - ), - max_seq_len=self.max_seq_len, - ) diff --git a/pytext/data/squad_for_bert_tensorizer.py b/pytext/data/squad_for_bert_tensorizer.py index 05d00131f..44eda3330 100644 --- a/pytext/data/squad_for_bert_tensorizer.py +++ b/pytext/data/squad_for_bert_tensorizer.py @@ -8,6 +8,7 @@ from pytext.config.component import ComponentType, create_component from pytext.data.bert_tensorizer import BERTTensorizer, build_fairseq_vocab from pytext.data.roberta_tensorizer 
+from pytext.data.tensorizers import lookup_tokens
 from pytext.data.tokenizers import Tokenizer
 from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, Vocabulary, pad_and_tensorize
 from pytext.torchscript.tensorizer import ScriptRoBERTaTensorizerWithIndices
@@ -51,6 +52,16 @@ def __init__(
         self.answers_column = answers_column
         self.answer_starts_column = answer_starts_column
 
+    def _lookup_tokens(self, text: str, seq_len: int = None):
+        return lookup_tokens(
+            text,
+            tokenizer=self.tokenizer,
+            vocab=self.vocab,
+            bos_token=None,
+            eos_token=self.vocab.eos_token,
+            max_seq_len=seq_len if seq_len else self.max_seq_len,
+        )
+
     def numberize(self, row):
         question_column, doc_column = self.columns
         doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
@@ -283,8 +294,15 @@ def __init__(
         self.answer_starts_column = answer_starts_column
         self.wrap_special_tokens = False
 
-    def _lookup_tokens(self, text):
-        return RoBERTaTensorizer._lookup_tokens(self, text)
+    def _lookup_tokens(self, text: str, seq_len: int = None):
+        return lookup_tokens(
+            text,
+            tokenizer=self.tokenizer,
+            vocab=self.vocab,
+            bos_token=self.vocab.bos_token,
+            eos_token=self.vocab.eos_token,
+            max_seq_len=seq_len if seq_len else self.max_seq_len,
+        )
 
     def torchscriptify(self):
         return ScriptRoBERTaTensorizerWithIndices(
diff --git a/pytext/data/tensorizers.py b/pytext/data/tensorizers.py
index 8033e0dcb..1751221c7 100644
--- a/pytext/data/tensorizers.py
+++ b/pytext/data/tensorizers.py
@@ -2,6 +2,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
 import contextlib
+import copy
 from typing import List, Optional
 
 import torch
@@ -21,6 +22,7 @@
 from pytext.torchscript.tensorizer import VectorNormalizer
 from pytext.utils import cuda, precision
 from pytext.utils.data import Slot
+from pytext.utils.lazy import lazy_property
 
 from .utils import (
     BOL,
@@ -269,9 +271,21 @@ def initialize(self, from_scratch=True):
         # we need yield here to make this function a generator
         yield
 
-    def torchscriptify(self):
+    @lazy_property
+    def tensorizer_script_impl(self):
+        # Script tensorizer is unpickleable, we use lazy_property for
+        # lazy initialization to construct the object during run time.
         raise NotImplementedError
 
+    def __getstate__(self):
+        # make a shallow copy of state to avoid side effect on the original object
+        state = copy.copy(vars(self))
+        state.pop("tensorizer_script_impl", None)
+        return state
+
+    def torchscriptify(self):
+        return self.tensorizer_script_impl.torchscriptify()
+
 
 class VocabFileConfig(Component.Config):
     #: File containing tokens to add to vocab (first whitespace-separated entry per
diff --git a/pytext/data/xlm_tensorizer.py b/pytext/data/xlm_tensorizer.py
index b09b59122..8ebabbdca 100644
--- a/pytext/data/xlm_tensorizer.py
+++ b/pytext/data/xlm_tensorizer.py
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
-import itertools
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
@@ -12,13 +11,12 @@
     BERTTensorizerBaseScriptImpl,
     build_fairseq_vocab,
 )
-from pytext.data.tensorizers import lookup_tokens
 from pytext.data.tokenizers import Tokenizer
 from pytext.data.utils import EOS, MASK, PAD, UNK, Vocabulary
 from pytext.data.xlm_constants import LANG2ID_15
-from pytext.torchscript.tensorizer import ScriptXLMTensorizer
 from pytext.torchscript.vocab import ScriptVocabulary
 from pytext.utils.file_io import PathManager
+from pytext.utils.lazy import lazy_property
 
 
 class XLMTensorizerScriptImpl(BERTTensorizerBaseScriptImpl):
@@ -198,6 +196,19 @@ def column_schema(self):
         schema += [(self.language_column, str)]
         return schema
 
+    @lazy_property
+    def tensorizer_script_impl(self):
+        languages = [0] * (max(list(self.lang2id.values())) + 1)
+        for k, v in self.lang2id.items():
+            languages[v] = k
+        return self.__TENSORIZER_SCRIPT_IMPL__(
+            tokenizer=self.tokenizer,
+            vocab=self.vocab,
+            language_vocab=languages,
+            max_seq_len=self.max_seq_len,
+            default_language=self.default_language,
+        )
+
     def get_lang_id(self, row: Dict, col: str) -> int:
         # generate lang embeddings. if training without lang embeddings, use
         # the first language as the lang_id (there will always be one lang)
@@ -210,54 +221,13 @@ def get_lang_id(self, row: Dict, col: str) -> int:
         # use En as default
         return self.lang2id.get(self.default_language, 0)
 
-    def _lookup_tokens(self, text: str, seq_len: int) -> List[str]:
-        return lookup_tokens(
-            text,
-            tokenizer=self.tokenizer,
-            vocab=self.vocab,
-            bos_token=self.vocab.eos_token,
-            eos_token=self.vocab.eos_token,
-            use_eos_token_for_bos=True,
-            max_seq_len=seq_len,
-        )
-
     def numberize(self, row: Dict) -> Tuple[Any, ...]:
-        sentences = []
-        language_column = self.language_column
-        columns = self.columns
-
-        # sequence_length is adjusted based on the number of text fields and needs
-        # to account for the special tokens which we will be wrapping
-        seq_len = self.max_seq_len // len(columns)
-        sentences = [
-            self._lookup_tokens(row[column], seq_len)[0] for column in self.columns
+        per_sentence_tokens = [
+            self.tokenizer.tokenize(row[column]) for column in self.columns
         ]
-        seq_lens = [len(sentence) for sentence in sentences]
-        lang_ids = [self.get_lang_id(row, language_column)] * len(self.columns)
-        # expand the language ids to each token
-        lang_ids = ([lang_id] * seq_len for lang_id, seq_len in zip(lang_ids, seq_lens))
-
-        tokens = list(itertools.chain(*sentences))
-        segment_labels = list(itertools.chain(*lang_ids))
-        seq_len = len(tokens)
-        positions = [index for index in range(seq_len)]
-        return tokens, segment_labels, seq_len, positions
-
-    def torchscriptify(self):
-        languages = [0] * (max(list(self.lang2id.values())) + 1)
-        for k, v in self.lang2id.items():
-            languages[v] = k
-
-        return ScriptXLMTensorizer(
-            tokenizer=self.tokenizer.torchscriptify(),
-            token_vocab=ScriptVocabulary(
-                list(self.vocab),
-                pad_idx=self.vocab.get_pad_index(),
-                bos_idx=self.vocab.get_eos_index(),
-                eos_idx=self.vocab.get_eos_index(),
-                unk_idx=self.vocab.get_unk_index(),
-            ),
-            language_vocab=ScriptVocabulary(languages),
-            max_seq_len=self.max_seq_len,
-            default_language=self.default_language,
+        per_sentence_languages = [self.get_lang_id(row, self.language_column)] * len(
+            self.columns
+        )
+        return self.tensorizer_script_impl.numberize(
+            per_sentence_tokens, per_sentence_languages
         )
diff --git a/pytext/torchscript/module.py b/pytext/torchscript/module.py
index 107bb4084..d7d93552e 100644
--- a/pytext/torchscript/module.py
+++ b/pytext/torchscript/module.py
@@ -37,7 +37,7 @@ def __init__(
 
     @torch.jit.script_method
     def forward(self, texts: List[str]):
-        input_tensors = self.tensorizer.tensorize(texts=squeeze_1d(texts))
+        input_tensors = self.tensorizer(texts=squeeze_1d(texts))
         logits = self.model(input_tensors)
         return self.output_layer(logits)
 
@@ -56,7 +56,7 @@ def __init__(
 
    @torch.jit.script_method
     def forward(self, tokens: List[List[str]]):
-        input_tensors = self.tensorizer.tensorize(tokens=squeeze_2d(tokens))
+        input_tensors = self.tensorizer(pre_tokenized=squeeze_2d(tokens))
         logits = self.model(input_tensors)
         return self.output_layer(logits)
 
@@ -75,8 +75,8 @@ def __init__(
 
     @torch.jit.script_method
     def forward(self, tokens: List[List[str]], languages: Optional[List[str]] = None):
-        input_tensors = self.tensorizer.tensorize(
-            tokens=squeeze_2d(tokens), languages=squeeze_1d(languages)
+        input_tensors = self.tensorizer(
+            pre_tokenized=squeeze_2d(tokens), languages=squeeze_1d(languages)
         )
         logits = self.model(input_tensors)
         return self.output_layer(logits)
@@ -101,8 +101,8 @@ def forward(
         dense_feat: List[List[float]],
         languages: Optional[List[str]] = None,
     ):
-        input_tensors = self.tensorizer.tensorize(
-            tokens=squeeze_2d(tokens), languages=squeeze_1d(languages)
+        input_tensors = self.tensorizer(
+            pre_tokenized=squeeze_2d(tokens), languages=squeeze_1d(languages)
         )
         logits = self.model(input_tensors, torch.tensor(dense_feat).float())
         return self.output_layer(logits)
diff --git a/pytext/utils/lazy.py b/pytext/utils/lazy.py
index 0638ed40d..0daf03cf7 100644
--- a/pytext/utils/lazy.py
+++ b/pytext/utils/lazy.py
@@ -8,6 +8,26 @@
 from torch import nn
 
 
+class lazy_property(object):
+    """
+    More or less copy-pasta: http://stackoverflow.com/a/6849299
+    Meant to be used for lazy evaluation of an object attribute.
+    property should represent non-mutable data, as it replaces itself.
+    """
+
+    def __init__(self, fget):
+        self._fget = fget
+        self.__doc__ = fget.__doc__
+        self.__name__ = fget.__name__
+
+    def __get__(self, obj, obj_cls_type):
+        if obj is None:
+            return None
+        value = self._fget(obj)
+        setattr(obj, self.__name__, value)
+        return value
+
+
 class UninitializedLazyModuleError(Exception):
     """A lazy module was used improperly."""
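
The lazy_property descriptor added in pytext/utils/lazy.py runs the decorated getter once per instance and then caches the result by writing an instance attribute of the same name, which shadows the descriptor on later lookups. Below is a minimal standalone sketch of that behavior; the Example class is illustrative only and not part of the patch:

    from pytext.utils.lazy import lazy_property

    class Example:
        def __init__(self):
            self.calls = 0

        @lazy_property
        def expensive(self):
            # runs only on the first attribute access
            self.calls += 1
            return object()

    e = Example()
    first = e.expensive   # descriptor __get__ runs the getter, then caches via setattr
    second = e.expensive  # plain instance-attribute lookup; the getter is not called again
    assert first is second and e.calls == 1
    assert "expensive" in vars(e)  # the cached value now lives in the instance __dict__

Since lazy_property defines only __get__, it is a non-data descriptor, so the cached instance attribute wins over the class-level descriptor after the first access; that is why its docstring warns that the property should represent non-mutable data.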
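
Relatedly, Tensorizer.__getstate__ in pytext/data/tensorizers.py pops the cached tensorizer_script_impl entry before pickling, because the TorchScript-backed implementation itself cannot be pickled; after unpickling, the next access to the lazy property simply rebuilds it. A rough sketch of the intended round trip, assuming some_tensorizer is a concrete tensorizer from this patch whose tokenizer and vocab pickle cleanly (the variable name is hypothetical):

    import pickle

    # touching the property builds and caches the (unpicklable) script impl
    impl = some_tensorizer.tensorizer_script_impl

    # __getstate__ shallow-copies vars(self) and pops "tensorizer_script_impl",
    # so the cached impl never reaches the pickle stream
    restored = pickle.loads(pickle.dumps(some_tensorizer))
    assert "tensorizer_script_impl" not in vars(restored)

    # the first access after unpickling runs the lazy_property getter again
    fresh_impl = restored.tensorizer_script_impl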