
Commit c511ac1

Use TensorizerImpl for both training and inference for BERT, RoBERTa and XLM tensorizer (#1195)

Summary:
Pull Request resolved: #1195

Use TensorizerImpl for both training and inference for BERT, RoBERTa and XLM tensorizer

Differential Revision: D18693298

fbshipit-source-id: 9e6910a80207c5c2aa7b6fc4550a851054ebb5a7
chenyangyu1988 authored and facebook-github-bot committed Dec 17, 2019
1 parent 779ba2b commit c511ac1
Showing 7 changed files with 94 additions and 170 deletions.
108 changes: 12 additions & 96 deletions pytext/data/bert_tensorizer.py
@@ -1,30 +1,20 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
from typing import Any, Dict, List, Optional, Tuple

import torch
from fairseq.data.dictionary import Dictionary
from fairseq.data.legacy.masked_lm_dictionary import BertDictionary
from pytext.config.component import ComponentType, create_component
from pytext.data.tensorizers import Tensorizer, TensorizerScriptImpl, lookup_tokens
from pytext.data.tensorizers import Tensorizer, TensorizerScriptImpl
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
from pytext.data.utils import (
BOS,
EOS,
MASK,
PAD,
UNK,
SpecialToken,
Vocabulary,
pad_and_tensorize,
)
from pytext.torchscript.tensorizer import ScriptBERTTensorizer
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, SpecialToken, Vocabulary
from pytext.torchscript.tensorizer.tensorizer import VocabLookup
from pytext.torchscript.utils import pad_2d, pad_2d_mask
from pytext.torchscript.vocab import ScriptVocabulary
from pytext.utils.file_io import PathManager
from pytext.utils.lazy import lazy_property


def build_fairseq_vocab(
@@ -294,74 +284,28 @@ def __init__(
def column_schema(self):
return [(column, str) for column in self.columns]

def _lookup_tokens(self, text: str, seq_len: int = None):
"""
This function knows how to call lookup_tokens with the correct
settings for this model. The default behavior is to wrap the
numberized text with distinct BOS and EOS tokens. The resulting
vector would look something like this:
[BOS, token1_id, . . . tokenN_id, EOS]
The function also takes an optional seq_len parameter which is
used to customize truncation in case we have multiple text fields.
By default max_seq_len is used. It's up to the numberize function of
the class to decide how to use the seq_len param.
For example:
- In the case of sentence pair classification, we might want both
pieces of text to have the same length, which is half of the
max_seq_len supported by the model.
- In the case of QA, we might want to truncate the context by a
seq_len which is longer than what we use for the question.
"""
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=self.vocab.bos_token,
eos_token=self.vocab.eos_token,
max_seq_len=seq_len if seq_len else self.max_seq_len,
@lazy_property
def tensorizer_script_impl(self):
return self.__TENSORIZER_SCRIPT_IMPL__(
tokenizer=self.tokenizer, vocab=self.vocab, max_seq_len=self.max_seq_len
)

def _wrap_numberized_text(
self, numberized_sentences: List[List[str]]
) -> List[List[str]]:
"""
If a class has a non-standard way of generating the final numberized text
(e.g. BERT) then a class-specific version of the wrap_numberized_text function
should be implemented. This allows us to share the numberize
function across classes without having to copy-paste code. The default
implementation doesn't do anything.
"""
return numberized_sentences

def numberize(self, row: Dict) -> Tuple[Any, ...]:
"""
This function contains logic for converting tokens into ids based on
the specified vocab. It also outputs, for each instance, the vectors
needed to run the actual model.
"""
sentences = [self._lookup_tokens(row[column])[0] for column in self.columns]
sentences = self._wrap_numberized_text(sentences)
seq_lens = (len(sentence) for sentence in sentences)
segment_labels = ([i] * seq_len for i, seq_len in enumerate(seq_lens))
tokens = list(itertools.chain(*sentences))
segment_labels = list(itertools.chain(*segment_labels))
seq_len = len(tokens)
positions = list(range(seq_len))
# tokens, segment_label, seq_len
return tokens, segment_labels, seq_len, positions
per_sentence_tokens = [
self.tokenizer.tokenize(row[column]) for column in self.columns
]
return self.tensorizer_script_impl.numberize(per_sentence_tokens)

def tensorize(self, batch) -> Tuple[torch.Tensor, ...]:
"""
Convert instance level vectors into batch level tensors.
"""
tokens, segment_labels, seq_lens, positions = zip(*batch)
tokens = pad_and_tensorize(tokens, self.vocab.get_pad_index())
pad_mask = (tokens != self.vocab.get_pad_index()).long()
segment_labels = pad_and_tensorize(segment_labels)
positions = pad_and_tensorize(positions)
return tokens, pad_mask, segment_labels, positions
return self.tensorizer_script_impl.tensorize_wrapper(*zip(*batch))

def initialize(self, vocab_builder=None, from_scratch=True):
# vocab for BERT is already set
@@ -455,31 +399,3 @@ def __init__(
super().__init__(
columns=columns, vocab=vocab, tokenizer=tokenizer, max_seq_len=max_seq_len
)

def _lookup_tokens(self, text: str, seq_len: int = None):
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=None,
eos_token=self.vocab.eos_token,
max_seq_len=seq_len if seq_len else self.max_seq_len,
)

def _wrap_numberized_text(
self, numberized_sentences: List[List[str]]
) -> List[List[str]]:
numberized_sentences[0] = [self.vocab.get_bos_index()] + numberized_sentences[0]
return numberized_sentences

def torchscriptify(self):
return ScriptBERTTensorizer(
tokenizer=self.tokenizer.torchscriptify(),
vocab=ScriptVocabulary(
list(self.vocab),
pad_idx=self.vocab.get_pad_index(),
bos_idx=self.vocab.get_bos_index(),
eos_idx=self.vocab.get_eos_index(),
),
max_seq_len=self.max_seq_len,
)
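The pattern above — an eager tensorizer that builds its ScriptImpl lazily and forwards numberize to it — is the core of this change. Below is a minimal, self-contained sketch of that delegation, using a toy whitespace tokenizer and functools.cached_property as a stand-in for pytext.utils.lazy.lazy_property; the class and method names are illustrative, not PyText's exact API.

```python
# Sketch only: mirrors the lazy_property + delegation pattern in the hunk above.
from functools import cached_property  # stand-in for pytext.utils.lazy.lazy_property
from typing import Dict, List, Tuple


class ToyScriptImpl:
    """Plays the role of a TensorizerScriptImpl: holds the shared numberizing logic."""

    def __init__(self, vocab: Dict[str, int], max_seq_len: int):
        self.vocab = vocab
        self.max_seq_len = max_seq_len

    def numberize(self, per_sentence_tokens: List[List[str]]) -> Tuple[List[int], int]:
        # Map tokens to ids, truncating each sentence to max_seq_len.
        token_ids = [
            self.vocab.get(token, 0)
            for sentence in per_sentence_tokens
            for token in sentence[: self.max_seq_len]
        ]
        return token_ids, len(token_ids)


class ToyTensorizer:
    """Eager-side tensorizer: tokenizes text, then defers to the script impl."""

    def __init__(self, vocab: Dict[str, int], max_seq_len: int, columns: List[str]):
        self.vocab = vocab
        self.max_seq_len = max_seq_len
        self.columns = columns

    @cached_property
    def tensorizer_script_impl(self) -> ToyScriptImpl:
        # Built lazily on first access, mirroring the lazy_property above.
        return ToyScriptImpl(vocab=self.vocab, max_seq_len=self.max_seq_len)

    def numberize(self, row: Dict[str, str]) -> Tuple[List[int], int]:
        # Tokenization stays on the eager side; numberizing is shared with
        # the TorchScript path through the script impl.
        per_sentence_tokens = [row[column].split() for column in self.columns]
        return self.tensorizer_script_impl.numberize(per_sentence_tokens)


if __name__ == "__main__":
    tensorizer = ToyTensorizer({"hello": 1, "world": 2}, max_seq_len=4, columns=["text"])
    print(tensorizer.numberize({"text": "hello world"}))  # ([1, 2], 2)
```

With this split, training and TorchScript inference exercise the same numberizing code instead of two hand-kept copies.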
14 changes: 0 additions & 14 deletions pytext/data/roberta_tensorizer.py
@@ -9,8 +9,6 @@
)
from pytext.data.tokenizers import GPT2BPETokenizer, Tokenizer
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK
from pytext.torchscript.tensorizer import ScriptRoBERTaTensorizer
from pytext.torchscript.vocab import ScriptVocabulary
from pytext.utils.file_io import PathManager


@@ -54,15 +52,3 @@ def from_config(cls, config: Config):
max_seq_len=config.max_seq_len,
base_tokenizer=base_tokenizer,
)

def torchscriptify(self):
return ScriptRoBERTaTensorizer(
tokenizer=self.tokenizer.torchscriptify(),
vocab=ScriptVocabulary(
list(self.vocab),
pad_idx=self.vocab.get_pad_index(),
bos_idx=self.vocab.get_bos_index(),
eos_idx=self.vocab.get_eos_index(),
),
max_seq_len=self.max_seq_len,
)
22 changes: 20 additions & 2 deletions pytext/data/squad_for_bert_tensorizer.py
@@ -8,6 +8,7 @@
from pytext.config.component import ComponentType, create_component
from pytext.data.bert_tensorizer import BERTTensorizer, build_fairseq_vocab
from pytext.data.roberta_tensorizer import RoBERTaTensorizer
from pytext.data.tensorizers import lookup_tokens
from pytext.data.tokenizers import Tokenizer
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, Vocabulary, pad_and_tensorize
from pytext.torchscript.tensorizer import ScriptRoBERTaTensorizerWithIndices
@@ -51,6 +52,16 @@ def __init__(
self.answers_column = answers_column
self.answer_starts_column = answer_starts_column

def _lookup_tokens(self, text: str, seq_len: int = None):
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=None,
eos_token=self.vocab.eos_token,
max_seq_len=seq_len if seq_len else self.max_seq_len,
)

def numberize(self, row):
question_column, doc_column = self.columns
doc_tokens, start_idx, end_idx = self._lookup_tokens(row[doc_column])
@@ -283,8 +294,15 @@ def __init__(
self.answer_starts_column = answer_starts_column
self.wrap_special_tokens = False

def _lookup_tokens(self, text):
return RoBERTaTensorizer._lookup_tokens(self, text)
def _lookup_tokens(self, text: str, seq_len: int = None):
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=self.vocab.bos_token,
eos_token=self.vocab.eos_token,
max_seq_len=seq_len if seq_len else self.max_seq_len,
)

def torchscriptify(self):
return ScriptRoBERTaTensorizerWithIndices(
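For context on the two _lookup_tokens overrides above: with bos_token=None the wrapped output is [token ids..., EOS], while passing the vocab's bos_token yields [BOS, token ids..., EOS], as described in the docstring removed from bert_tensorizer.py earlier in this diff. The toy wrap function below only illustrates those two conventions; it is not pytext's lookup_tokens.

```python
# Illustration of the two wrapping conventions, with made-up token ids.
from typing import List, Optional


def wrap(token_ids: List[int], bos: Optional[int], eos: Optional[int]) -> List[int]:
    # Prepend BOS and/or append EOS only when an index is provided.
    wrapped = list(token_ids)
    if bos is not None:
        wrapped = [bos] + wrapped
    if eos is not None:
        wrapped = wrapped + [eos]
    return wrapped


BOS_IDX, EOS_IDX = 0, 2
print(wrap([11, 12, 13], bos=None, eos=EOS_IDX))     # [11, 12, 13, 2]
print(wrap([11, 12, 13], bos=BOS_IDX, eos=EOS_IDX))  # [0, 11, 12, 13, 2]
```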
16 changes: 15 additions & 1 deletion pytext/data/tensorizers.py
@@ -2,6 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import contextlib
import copy
from typing import List, Optional

import torch
@@ -21,6 +22,7 @@
from pytext.torchscript.tensorizer import VectorNormalizer
from pytext.utils import cuda, precision
from pytext.utils.data import Slot
from pytext.utils.lazy import lazy_property

from .utils import (
BOL,
@@ -269,9 +271,21 @@ def initialize(self, from_scratch=True):
# we need yield here to make this function a generator
yield

def torchscriptify(self):
@lazy_property
def tensorizer_script_impl(self):
# The script tensorizer is not picklable, so we use lazy_property to
# construct the object lazily at run time.
raise NotImplementedError

def __getstate__(self):
# make a shallow copy of state to avoid side effect on the original object
state = copy.copy(vars(self))
state.pop("tensorizer_script_impl", None)
return state

def torchscriptify(self):
return self.tensorizer_script_impl.torchscriptify()


class VocabFileConfig(Component.Config):
#: File containing tokens to add to vocab (first whitespace-separated entry per
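The __getstate__/lazy_property pair added above exists because the TorchScript-side impl cannot be pickled. Here is a minimal, self-contained sketch of that pattern, assuming only that the cached impl is dropped from the pickled state and rebuilt lazily after unpickling; functools.cached_property stands in for pytext.utils.lazy.lazy_property and all names are illustrative.

```python
# Sketch only: shows why popping the cached script impl keeps the tensorizer picklable.
import copy
import pickle
from functools import cached_property  # stand-in for pytext.utils.lazy.lazy_property


class Unpicklable:
    """Pretend TorchScript impl that refuses to be pickled."""

    def __reduce__(self):
        raise TypeError("cannot pickle a script impl")


class PicklableTensorizer:
    def __init__(self, max_seq_len: int):
        self.max_seq_len = max_seq_len

    @cached_property
    def tensorizer_script_impl(self) -> Unpicklable:
        return Unpicklable()

    def __getstate__(self):
        # Shallow-copy the state and drop the cached impl so pickling works;
        # it is rebuilt lazily on the next attribute access.
        state = copy.copy(vars(self))
        state.pop("tensorizer_script_impl", None)
        return state


t = PicklableTensorizer(max_seq_len=256)
_ = t.tensorizer_script_impl              # cache the unpicklable object
restored = pickle.loads(pickle.dumps(t))  # succeeds: the impl was dropped from state
assert isinstance(restored.tensorizer_script_impl, Unpicklable)  # rebuilt lazily
```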
72 changes: 21 additions & 51 deletions pytext/data/xlm_tensorizer.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import itertools
from typing import Any, Dict, List, Optional, Tuple

import torch
@@ -12,13 +11,12 @@
BERTTensorizerBaseScriptImpl,
build_fairseq_vocab,
)
from pytext.data.tensorizers import lookup_tokens
from pytext.data.tokenizers import Tokenizer
from pytext.data.utils import EOS, MASK, PAD, UNK, Vocabulary
from pytext.data.xlm_constants import LANG2ID_15
from pytext.torchscript.tensorizer import ScriptXLMTensorizer
from pytext.torchscript.vocab import ScriptVocabulary
from pytext.utils.file_io import PathManager
from pytext.utils.lazy import lazy_property


class XLMTensorizerScriptImpl(BERTTensorizerBaseScriptImpl):
@@ -198,6 +196,19 @@ def column_schema(self):
schema += [(self.language_column, str)]
return schema

@lazy_property
def tensorizer_script_impl(self):
languages = [0] * (max(list(self.lang2id.values())) + 1)
for k, v in self.lang2id.items():
languages[v] = k
return self.__TENSORIZER_SCRIPT_IMPL__(
tokenizer=self.tokenizer,
vocab=self.vocab,
language_vocab=languages,
max_seq_len=self.max_seq_len,
default_language=self.default_language,
)

def get_lang_id(self, row: Dict, col: str) -> int:
# generate lang embeddings. if training without lang embeddings, use
# the first language as the lang_id (there will always be one lang)
@@ -210,54 +221,13 @@ def get_lang_id(self, row: Dict, col: str) -> int:
# use En as default
return self.lang2id.get(self.default_language, 0)

def _lookup_tokens(self, text: str, seq_len: int) -> List[str]:
return lookup_tokens(
text,
tokenizer=self.tokenizer,
vocab=self.vocab,
bos_token=self.vocab.eos_token,
eos_token=self.vocab.eos_token,
use_eos_token_for_bos=True,
max_seq_len=seq_len,
)

def numberize(self, row: Dict) -> Tuple[Any, ...]:
sentences = []
language_column = self.language_column
columns = self.columns

# sequence_length is adjusted based on the number of text fields and needs
# to account for the special tokens which we will be wrapping
seq_len = self.max_seq_len // len(columns)
sentences = [
self._lookup_tokens(row[column], seq_len)[0] for column in self.columns
per_sentence_tokens = [
self.tokenizer.tokenize(row[column]) for column in self.columns
]
seq_lens = [len(sentence) for sentence in sentences]
lang_ids = [self.get_lang_id(row, language_column)] * len(self.columns)
# expand the language ids to each token
lang_ids = ([lang_id] * seq_len for lang_id, seq_len in zip(lang_ids, seq_lens))

tokens = list(itertools.chain(*sentences))
segment_labels = list(itertools.chain(*lang_ids))
seq_len = len(tokens)
positions = [index for index in range(seq_len)]
return tokens, segment_labels, seq_len, positions

def torchscriptify(self):
languages = [0] * (max(list(self.lang2id.values())) + 1)
for k, v in self.lang2id.items():
languages[v] = k

return ScriptXLMTensorizer(
tokenizer=self.tokenizer.torchscriptify(),
token_vocab=ScriptVocabulary(
list(self.vocab),
pad_idx=self.vocab.get_pad_index(),
bos_idx=self.vocab.get_eos_index(),
eos_idx=self.vocab.get_eos_index(),
unk_idx=self.vocab.get_unk_index(),
),
language_vocab=ScriptVocabulary(languages),
max_seq_len=self.max_seq_len,
default_language=self.default_language,
per_sentence_languages = [self.get_lang_id(row, self.language_column)] * len(
self.columns
)
return self.tensorizer_script_impl.numberize(
per_sentence_tokens, per_sentence_languages
)
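The tensorizer_script_impl property added above inverts lang2id into a dense list indexed by language id before handing it to the script impl. A small sketch of that inversion, using a made-up LANG2ID subset rather than PyText's LANG2ID_15:

```python
# Sketch only: invert an id mapping into a dense, id-indexed list of languages.
LANG2ID = {"en": 0, "fr": 1, "de": 4}  # illustrative subset, not LANG2ID_15

languages = [0] * (max(LANG2ID.values()) + 1)
for lang, idx in LANG2ID.items():
    languages[idx] = lang

print(languages)  # ['en', 'fr', 0, 0, 'de'] -- unused ids keep the placeholder 0
```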
(Diffs for the remaining two changed files are not shown.)
