Complete pytorch transformers interface, deprecate old GPT implementation (#881)

* Rename namespaces to suppress warnings.

* Revert "Rename namespaces to suppress warnings."

This reverts commit 0cf7b23.

* Initial working-ish attempt.

* Intermediate check-in...

* More partial progress.

* Another pass...

* Fix sep/cls handling, cleanup.

* Further cleanup.

* Keyword name fix.

* Another flag fix.

* Pull debug print.

* Line length cleanup.

* WiC fix.

* Two task setup bugs.

* BoolQ typo

* Improved segment handling.

* Delete unused is_pair_task, other cleanup/fixes.

* Fix deleted path from merge.

* Fix cache path.

* relocate tasks from seminar

* add linguistic phenomena benchmark tasks

* Address (spurious?) tokenization warning.

* Select pool_type automatically to match model.

h/t Haokun Liu

* Config updates.

* Path fix

* add two prefix methods and a simple LM

* Fix XLNet UNK handling.

* Internal temporary MNLI alternate.

* Revert "Internal temporary MNLI alternate."

This reverts commit 455792a.

* refactor tags in data loader

* Add helper fn tests

* Finish merge

* Remove unused argument.

* update task init

* Possible ReCoRD bug fix

* Cleanup

* Fix merge issues.

* Revert "Remove unused argument."

This reverts commit 96a7c37.

* Assorted responses to Alex's comments.

* Further ReCoRD fix.

* @iftenney's comments.

* Fix/simplify segment logic.

* @W4ngatang's comments

* Cleanup.

* add forward functions

* bugfix

* merge pytorch transformer

* update old process split

* add gpt2

* add get_pretrained_lm_head for transformers

* update filename

* add config

* debug

* update config

* allow evaluate with raw parameter

* debug

* Cleanup

* Fix issues with alternative embeddings_mode settings, max_layer.

* More mix cleanup.

* Masking fix.

* cleanup

* simplify get_seg_ids

* debug

* related adjustments to add pytorch transformers

* pytorch transformer refactor

* formatting

* formatting

* debug

* TransformerXL fix

* update test script

* formatting again

* add note to transfo-xl

* debug

* update test script

* update test script

* tokenized_name change

* cleanup

* pool type fix

* config update

* Update defaults.conf

* rename use_pytorch_transformer

* cleanup

* Update test_preprocess.py

* Update test_checkpointing.py

* Update test_write_preds.py

* clean up

* debug

* name changes

* name changes

* update message

* name changes

* tokenizer name fix

* docstring changes

* name changes

* restore asserts

* add pair embedding for pytorch_transformers

* add max position embedding assert

* deal with gpt-like boundary fn

* roberta tokenizer support

* roberta model support

* roberta embedder

* fix roberta seg_id

* change unused_task_name message

* more test cases for pytorch_transformers_interface

* gpt-style mirrored pair forward func for similarity tasks

* Update environment.yml

* adjust import location

* black

* move import location

* update test script

* add comments to test script

* update test script

* pool type fix

* tokenizer fix

* debug

* special tokens fix

* roberta vocab fix

* roberta tokenizer fix

* clean up

* Update test_pytorch_transformers_interface.py

* add_special_token fix

* black

* fix roberta message logic

* fix embedding extend bug

* black

* clean up

* simplify add_special_token fix

* add assert for lm task & pytorch_transformers

* black

* relocate task_modulator initialization

* minor changes

* rename task_modulator -> model_preprocessing_interface

* change lm_parsing process_split docstring

* black

* add gpt2-large

* update dependency

* update dependency for real

* clean up

* add a forgotten similarity task for gpt

* update setup

* update setup
HaokunLiu authored and sleepinyourhat committed Aug 26, 2019
1 parent 815beea commit 6921e4d
Showing 25 changed files with 1,117 additions and 700 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
@@ -2,6 +2,3 @@
path = jiant/modules/cove
url = https://github.com/salesforce/cove.git
ignore = untracked
[submodule "jiant/openai_transformer_lm/pytorch_huggingface"]
path = jiant/openai_transformer_lm/pytorch_huggingface
url = https://github.com/huggingface/pytorch-openai-transformer-lm.git
5 changes: 2 additions & 3 deletions environment.yml
@@ -28,13 +28,12 @@ dependencies:
- python-Levenshtein==0.12.0
# for --remote_log functionality
- google-cloud-logging==1.11.0
# ftfy and spacy are used for GPT
- ftfy==5.4.1
# spacy is used for some tokenizers in pytorch-transformers
- spacy==2.0.11

# Warning: jiant currently depends on *both* pytorch_pretrained_bert > 0.6 _and_
# pytorch_transformers > 1.0. These are the same package, though the name changed between
# these two versions. AllenNLP requires 0.6 to support the BertAdam optimizer, and jiant
# directly requires 1.0 to support XLNet and WWM-BERT.
# This AllenNLP issue is relevant: https://github.com/allenai/allennlp/issues/3067
- pytorch-transformers==1.0.0
- pytorch-transformers==1.1.0
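
The dual pin above works because the two releases, although versions of the same project, install under different module names, so AllenNLP and jiant can each import the one they need from a single environment. A minimal illustration (not repository code):

from pytorch_pretrained_bert import BertAdam  # 0.6.x name, needed by AllenNLP's optimizer
from pytorch_transformers import GPT2Model, RobertaModel, XLNetModel  # 1.x name, needed by jiant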
40 changes: 16 additions & 24 deletions jiant/config/defaults.conf
@@ -237,16 +237,21 @@ input_module = "" // The word embedding or contextual word representation layer
// - elmo-chars-only: The dynamic CNN-based word embedding layer of AllenNLP's
// ELMo, but not ELMo's LSTM layer hidden states. Use with
// tokenizer = MosesTokenizer.
// - gpt: The OpenAI GPT language model encoder.
// Use with tokenizer = OpenAI.BPE.
// - bert-base-uncased, etc.: Any BERT model specifier that is valid for
// pytorch-pretrained-bert may be specified here. Use with
// tokenizer = ${input_module}
// We support the newer bert-large-uncased-whole-word-masking and
// bert-large-cased-whole-word-masking cased models, but they require
// the git development version of pytorch-pretrained-bert. To use these
// models, follow the instructions under 'From source' here:
// https://github.com/huggingface/pytorch-pretrained-BERT
// - bert-base-uncased, etc.: Any BERT model from pytorch_transformers.
// - roberta-base / roberta-large / roberta-large-mnli: RoBERTa model from
// pytorch_transformers.
// - xlnet-base-cased / xlnet-large-cased: XLNet Model from
// pytorch_transformers.
// - openai-gpt: The OpenAI GPT language model encoder from
// pytorch_transformers.
// - gpt2 / gpt2-medium / gpt2-large: The OpenAI GPT-2 language model encoder from
// pytorch_transformers.
// - transfo-xl-wt103: The Transformer-XL language model encoder from
// pytorch_transformers.
// - xlm-mlm-en-2048: XLM English language model encoder from
// pytorch_transformers.
// Note: Any input_module from pytorch_transformers requires
// tokenizer = ${input_module} or auto.

tokenizer = auto // The name of the tokenizer, passed to the Task constructor for
// appropriate handling during data loading. Currently supported
@@ -257,8 +262,7 @@ tokenizer = auto // The name of the tokenizer, passed to the Task constructor f
// - MosesTokenizer: Our standard word tokenizer. (Support for
// other NLTK tokenizers is pending.)
// - bert-uncased-base, etc.: Use the tokenizer supplied with
// pytorch-pretrained-bert that corresponds to that BERT model.
// - OpenAI.BPE: The tokenizer supplied with OpenAI GPT.
// pytorch_transformers that corresponds to the input_module.

word_embs_file = ${WORD_EMBS_FILE} // Path to embeddings file, used with glove and fastText.
d_word = 300 // Dimension of word embeddings, used with scratch, glove, or fastText.
@@ -272,18 +276,6 @@ d_char = 100 // Dimension of trained char embeddings.
n_char_filters = 100 // Number of filters in trained char CNN.
char_filter_sizes = "2,3,4,5" // Size of char CNN filters.

openai_transformer_ckpt = "" // If non-empty, will load OpenAI Transformer from the given
// TensorFlow checkpoint. Checkpoint should be as created by the
// original release (openai/finetune-transformer-lm).
openai_embeddings_mode = "none" // How to handle the embedding layer of the OpenAI Transformer
// model:
// "none" or "top" returns only top-layer activation,
// "cat" returns top-layer concatenated with
// lexical layer,
// "only" returns only lexical layer,
// "mix" uses ELMo-style scalar mixing (with
// learned weights) across all layers.

pytorch_transformers_output_mode = "none" // How to handle the embedding layer of the
// BERT/XLNet model:
// "none" or "top" returns only top-layer activation,
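
The note above that any pytorch_transformers input_module requires tokenizer = ${input_module} or auto boils down to a simple resolution rule. A minimal sketch of that rule (illustrative only; the prefix list and the MosesTokenizer fallback are assumptions, not code from this commit):

def resolve_tokenizer(input_module, tokenizer="auto"):
    if tokenizer != "auto":
        return tokenizer
    pytorch_transformers_prefixes = (
        "bert-", "roberta-", "xlnet-", "openai-gpt", "gpt2", "transfo-xl-", "xlm-",
    )
    if input_module.startswith(pytorch_transformers_prefixes):
        # pytorch_transformers tokenizers are keyed by the model name itself
        return input_module
    # assumed fallback for ELMo / GloVe / fastText-style input modules
    return "MosesTokenizer"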
4 changes: 1 addition & 3 deletions jiant/config/edgeprobe/edgeprobe_openai.conf
@@ -23,10 +23,8 @@ write_preds = "val,test"
lr_patience = 5 // vals until LR decay
patience = 20 // vals until early-stopping

tokenizer = "OpenAI.BPE"
cove = 0
input_module = "gpt"
openai_transformer_ckpt = "" // use default weights
input_module = "openai-gpt"

// Use no-op encoder (no params).
sent_enc = "none"
107 changes: 82 additions & 25 deletions jiant/models.py
@@ -59,6 +59,8 @@
STSBTask,
TaggingTask,
WiCTask,
MRPCTask,
QQPTask,
)
from jiant.utils import config
from jiant.utils.utils import (
@@ -227,31 +229,51 @@ def build_model(args, vocab, pretrained_embs, tasks):

# Build embeddings.
cove_layer = None
if args.input_module == "gpt":
# Note: incompatible with other embedders, but logic in preprocess.py
# should prevent these from being enabled anyway.
from .openai_transformer_lm.utils import OpenAIEmbedderModule

log.info("Using OpenAI transformer model.")
# Here, this uses openAIEmbedder.
embedder = OpenAIEmbedderModule(args)
d_emb = embedder.get_output_dim()
elif args.input_module.startswith("bert"):
if args.input_module.startswith("bert-"):
from jiant.pytorch_transformers_interface.modules import BertEmbedderModule

log.info(f"Using BERT model ({args.input_module}).")
embedder = BertEmbedderModule(args)
d_emb = embedder.get_output_dim()
elif args.input_module.startswith("xlnet"):
elif args.input_module.startswith("roberta-"):
from jiant.pytorch_transformers_interface.modules import RobertaEmbedderModule

log.info(f"Using RoBERTa model ({args.input_module}).")
embedder = RobertaEmbedderModule(args)
d_emb = embedder.get_output_dim()
elif args.input_module.startswith("xlnet-"):
from jiant.pytorch_transformers_interface.modules import XLNetEmbedderModule

log.info(f"Using XLNet model ({args.input_module}).")
embedder = XLNetEmbedderModule(args)
d_emb = embedder.get_output_dim()
elif args.input_module.startswith("openai-gpt"):
from jiant.pytorch_transformers_interface.modules import OpenAIGPTEmbedderModule

log.info(f"Using OpenAI GPT model ({args.input_module}).")
embedder = OpenAIGPTEmbedderModule(args)
d_emb = embedder.get_output_dim()
elif args.input_module.startswith("gpt2"):
from jiant.pytorch_transformers_interface.modules import GPT2EmbedderModule

log.info(f"Using GPT-2 model ({args.input_module}).")
embedder = GPT2EmbedderModule(args)
d_emb = embedder.get_output_dim()
elif args.input_module.startswith("transfo-xl-"):
from jiant.pytorch_transformers_interface.modules import TransfoXLEmbedderModule

log.info(f"Using Transformer-XL model ({args.input_module}).")
embedder = TransfoXLEmbedderModule(args)
d_emb = embedder.get_output_dim()
elif args.input_module.startswith("xlm-"):
from jiant.pytorch_transformers_interface.modules import XLMEmbedderModule

log.info(f"Using XLM model ({args.input_module}).")
embedder = XLMEmbedderModule(args)
d_emb = embedder.get_output_dim()
else:
# Default case, used for ELMo, CoVe, word embeddings, etc.
d_emb, embedder, cove_layer = build_embeddings(args, vocab, tasks, pretrained_embs)
d_sent_input = args.d_hid

sent_encoder, d_sent_output = build_sent_encoder(
args, vocab, d_emb, tasks, embedder, cove_layer
@@ -312,7 +334,6 @@ def build_embeddings(args, vocab, tasks, pretrained_embs=None):
word_embs = nn.Embedding(n_token_vocab, d_word).weight
else:
assert input_module_uses_pytorch_transformers(args.input_module) or args.input_module in [
"gpt",
"elmo",
"elmo-chars-only",
], f"'{args.input_module}' is not a valid value for input_module."
@@ -508,6 +529,12 @@ def build_task_specific_modules(task, model, d_sent, d_emb, vocab, embedder, arg
setattr(model, "%s_hid2voc" % task.name, hid2voc)
setattr(model, "%s_mdl" % task.name, hid2voc)
elif isinstance(task, LanguageModelingTask):
assert not input_module_uses_pytorch_transformers(args.input_module), (
"Our LM task does not support pytorch_transformers. If you need it, update the "
"corresponding parts of the code; get_pretrained_lm_head and apply_lm_boundary_tokens "
"from pytorch_transformers_interface.modules may be useful, but check that they "
"behave correctly first."
)
d_sent = args.d_hid + (args.skip_embs * d_emb)
hid2voc = build_lm(task, d_sent, args)
setattr(model, "%s_hid2voc" % task.name, hid2voc)
@@ -658,7 +685,7 @@ def build_pair_attn(d_in, d_hid_attn):

# Build the classifier
n_classes = task.n_classes if hasattr(task, "n_classes") else 1
if model.use_pytorch_transformers:
if model.uses_pair_embedding:
# BERT/XLNet handle pair tasks by concatenating the inputs and classifying the joined
# sequence, so we use a single sentence classifier
if isinstance(task, WiCTask):
@@ -751,9 +778,11 @@ def __init__(self, args, sent_encoder, vocab):
self.vocab = vocab
self.utilization = Average() if args.track_batch_utilization else None
self.elmo = args.input_module == "elmo"
self.use_pytorch_transformers = input_module_uses_pytorch_transformers(args.input_module)
self.uses_pair_embedding = input_module_uses_pair_embedding(args.input_module)
self.uses_mirrored_pair = input_module_uses_mirrored_pair(args.input_module)
self.project_before_pooling = not (
self.use_pytorch_transformers and args.transfer_paradigm == "finetune"
input_module_uses_pytorch_transformers(args.input_module)
and args.transfer_paradigm == "finetune"
) # Rough heuristic. TODO: Make this directly user-controllable.
self.sep_embs_for_skip = args.sep_embs_for_skip

@@ -860,7 +889,7 @@ def _nli_diagnostic_forward(self, batch, task, predict):

# embed the sentence
classifier = self._get_classifier(task)
if self.use_pytorch_transformers:
if self.uses_pair_embedding:
sent, mask = self.sent_encoder(batch["inputs"], task)
logits = classifier(sent, mask)
else:
@@ -898,7 +927,13 @@ def _pair_sentence_forward(self, batch, task, predict):

# embed the sentence
classifier = self._get_classifier(task)
if self.use_pytorch_transformers:
if isinstance(task, (MRPCTask, STSBTask, QQPTask)) and self.uses_mirrored_pair:
# Mirrored pair is a trick used by GPT-like models on similarity tasks.
# TODO: WiC also falls into this category, although the GPT paper didn't experiment with it.
sent, mask = self.sent_encoder(batch["inputs"], task)
sent_m, mask_m = self.sent_encoder(batch["inputs_m"], task)
logits = classifier(sent, mask) + classifier(sent_m, mask_m)
elif self.uses_pair_embedding:
sent, mask = self.sent_encoder(batch["inputs"], task)
# special case for WiC b/c we want to add representations of particular tokens
if isinstance(task, WiCTask):
@@ -1057,12 +1092,11 @@ def _mc_forward(self, batch, task, predict):

logits = []
module = self._get_classifier(task)
if self.use_pytorch_transformers:
if self.uses_pair_embedding:
for choice_idx in range(task.n_choices):
sent, mask = self.sent_encoder(batch["choice%d" % choice_idx], task)
logit = module(sent, mask)
logits.append(logit)
out["n_exs"] = batch["choice0"]["pytorch_transformers_wpm_pretokenized"].size(0)
else:
ctx, ctx_mask = self.sent_encoder(batch["question"], task)
for choice_idx in range(task.n_choices):
@@ -1071,9 +1105,9 @@ def _mc_forward(self, batch, task, predict):
inp_mask = torch.cat([ctx_mask, mask], dim=1)
logit = module(inp, inp_mask)
logits.append(logit)
out["n_exs"] = batch["choice0"]["words"].size(0)
logits = torch.cat(logits, dim=1)
out["logits"] = logits
out["n_exs"] = get_batch_size(batch, keyword="choice0")

if "label" in batch:
labels = batch["label"]
@@ -1128,12 +1162,12 @@ def _multiple_choice_reading_comprehension_forward(self, batch, task, predict):
"""
out = {}
classifier = self._get_classifier(task)
if self.use_pytorch_transformers:
if self.uses_pair_embedding:
# if using BERT/XLNet, we concatenate the passage, question, and answer
inp = batch["psg_qst_ans"]
ex_embs, ex_mask = self.sent_encoder(inp, task)
logits = classifier(ex_embs, ex_mask)
out["n_exs"] = inp["pytorch_transformers_wpm_pretokenized"].size(0)
out["n_exs"] = get_batch_size(batch, keyword="psg_qst_ans")
else:
# else, we embed each independently and concat them
psg_emb, psg_mask = self.sent_encoder(batch["psg"], task)
@@ -1143,11 +1177,11 @@ def _multiple_choice_reading_comprehension_forward(self, batch, task, predict):
ans_emb, ans_mask = self.sent_encoder(batch["ans"], task)
inp = torch.cat([psg_emb, qst_emb, ans_emb], dim=1)
inp_mask = torch.cat([psg_mask, qst_mask, ans_mask], dim=1)
out["n_exs"] = batch["ans"]["words"].size(0)
out["n_exs"] = get_batch_size(batch, keyword="ans")
else: # ReCoRD inserts answer into the query
inp = torch.cat([psg_emb, qst_emb], dim=1)
inp_mask = torch.cat([psg_mask, qst_mask], dim=1)
out["n_exs"] = batch["qst"]["words"].size(0)
out["n_exs"] = get_batch_size(batch, keyword="qst")

logits = classifier(inp, inp_mask)
out["logits"] = logits
@@ -1196,3 +1230,26 @@ def get_elmo_mixing_weights(self, tasks=[]):
self.sent_encoder._text_field_embedder, task=None
)
return params


def input_module_uses_pair_embedding(input_module):
"""
This function tells whether the input module concatenates the two sentences in a pair when
running on pair tasks, like what GPT / BERT do on MNLI.
It seems redundant now, but it allows us to load similar models from other sources later on.
"""
from jiant.pytorch_transformers_interface import input_module_uses_pytorch_transformers

return input_module_uses_pytorch_transformers(input_module)


def input_module_uses_mirrored_pair(input_module):
"""
This function tells whether the input module uses the raw pair and the mirrored pair simultaneously
when running on symmetric pair tasks, like what GPT does on STS-B.
"""
return (
input_module.startswith("openai-gpt")
or input_module.startswith("gpt2")
or input_module.startswith("transfo-xl-")
)
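
The mirrored-pair handling wired into _pair_sentence_forward above amounts to scoring the pair in both orders and summing the logits, so the prediction for a symmetric task does not depend on sentence order. A standalone sketch (not additional repository code), assuming batch["inputs_m"] holds the same pair joined in reverse order, as in the diff:

def mirrored_pair_logits(sent_encoder, classifier, batch, task):
    # Encode "A <sep> B" and "B <sep> A", then sum the two classifier scores.
    sent, mask = sent_encoder(batch["inputs"], task)
    sent_m, mask_m = sent_encoder(batch["inputs_m"], task)
    return classifier(sent, mask) + classifier(sent_m, mask_m)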
3 changes: 2 additions & 1 deletion jiant/modules/sentence_encoder.py
@@ -84,6 +84,8 @@ def forward(self, sent, task, reset=True):
self.reset_states()

# General sentence embeddings (for sentence encoder).
# Make sent_mask first, because the pytorch_transformers text_field_embedder will change the token indices.
sent_mask = util.get_text_field_mask(sent).float()
# Skip this for probing runs that don't need it.
if not isinstance(self._phrase_layer, NullPhraseLayer):
word_embs_in_context = self._highway_layer(self._text_field_embedder(sent))
@@ -133,7 +135,6 @@ def forward(self, sent, task, reset=True):
task_word_embs_in_context = self._dropout(task_word_embs_in_context)

# The rest of the model
sent_mask = util.get_text_field_mask(sent).float()
sent_lstm_mask = sent_mask if self._mask_lstms else None
if word_embs_in_context is not None:
if isinstance(self._phrase_layer, ONLSTMStack) or isinstance(self._phrase_layer, PRPN):
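
The reordering above matters because the pytorch_transformers text_field_embedder changes the token indices when it runs, so a mask computed afterwards would no longer match the original padding pattern. A minimal sketch of the intended ordering, using AllenNLP's mask helper (illustrative only):

from allennlp.nn import util

def embed_with_mask(text_field_embedder, sent):
    sent_mask = util.get_text_field_mask(sent).float()  # build the mask from the raw token ids
    word_embs_in_context = text_field_embedder(sent)  # this call may rewrite the token indices
    return word_embs_in_context, sent_mask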
Empty file.
1 change: 0 additions & 1 deletion jiant/openai_transformer_lm/pytorch_huggingface
Submodule pytorch_huggingface deleted from bfd8e0
