From 1551a812b943fa67c0b57de8e7130b7d9a0a703e Mon Sep 17 00:00:00 2001
From: dpressel
Date: Tue, 1 Mar 2022 20:33:04 -0500
Subject: [PATCH 1/4] add SPM and REPL options to MLM example

---
 mead/api_examples/generate_mlm.py | 53 +++++++++++++++++++++++++------
 1 file changed, 43 insertions(+), 10 deletions(-)

diff --git a/mead/api_examples/generate_mlm.py b/mead/api_examples/generate_mlm.py
index c6a582760..961a3aacb 100644
--- a/mead/api_examples/generate_mlm.py
+++ b/mead/api_examples/generate_mlm.py
@@ -9,7 +9,7 @@
 from eight_mile.pytorch.serialize import tlm_load_state_dict, load_tlm_npz
 from baseline.pytorch.lm import TransformerMaskedLanguageModel
 from eight_mile.utils import str2bool, read_json, Offsets, revlut
-from baseline.vectorizers import Token1DVectorizer, BPEVectorizer1D
+from baseline.vectorizers import Token1DVectorizer, BPEVectorizer1D, WordpieceVectorizer1D
 from baseline.pytorch.embeddings import *
 from mead.api_examples.transformer_utils import find_latest_checkpoint
 logger = logging.getLogger(__file__)
@@ -85,12 +85,22 @@ def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_va
     return model
 
 
+def get_subword_vec1d(subword_type):
+    if subword_type == 'bpe':
+        return BPEVectorizer1D
+    elif subword_type == 'wordpiece':
+        return WordpieceVectorizer1D
+    else:
+        from baseline.vectorizers import SentencePieceVectorizer1D  # imported lazily so the sentencepiece dependency stays optional
+        return SentencePieceVectorizer1D
+
+
 def main():
     parser = ArgumentParser()
     parser.add_argument("--basedir", type=str)
     parser.add_argument("--checkpoint", type=str, help='Checkpoint name or directory to load')
     parser.add_argument("--sample", type=str2bool, help='Sample from the decoder? Defaults to `false`', default=0)
-    parser.add_argument("--query", type=str, default='hello , are you today ?')
+    parser.add_argument("--query", type=str)
     parser.add_argument("--dataset_cache", type=str, default=os.path.expanduser('~/.bl-data'),
                         help="Path or url of the dataset cache")
     parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)")
@@ -101,8 +111,9 @@ def main():
     parser.add_argument("--nctx", type=int, default=128, help="Max context length (for both encoder and decoder)")
     parser.add_argument("--embed_type", type=str, default='default',
                         help="register label of the embeddings, so far support positional or learned-positional")
-    parser.add_argument("--subword_model_file", type=str, required=True)
-    parser.add_argument("--subword_vocab_file", type=str, required=True)
+    parser.add_argument("--subword_model_file", type=str, required=False)
+    parser.add_argument("--subword_vocab_file", type=str, required=False)
+    parser.add_argument("--subword_type", type=str, choices=["bpe", "wordpiece", "sentencepiece"], default="bpe")
     parser.add_argument("--use_cls", type=str2bool, default=False)
     parser.add_argument("--rpr_value_on", type=str2bool, default=False)
     parser.add_argument('--end_token', default='')
@@ -115,8 +126,10 @@ def main():
     parser.add_argument("--device", type=str,
                         default="cuda" if torch.cuda.is_available() else "cpu",
                         help="Device (cuda or cpu)")
+    parser.add_argument("--begin_token", default=["[CLS]"])
     args = parser.parse_args()
 
+
     if torch.cuda.device_count() == 1:
         torch.cuda.set_device(0)
         args.device = torch.device("cuda", 0)
@@ -128,9 +141,11 @@ def main():
     else:
         checkpoint = args.checkpoint
 
-    cls = None if not args.use_cls else '[CLS]'
-    end = args.end_token
-    vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx, emit_begin_tok=cls, emit_end_tok=end, extra_tokens=args.extra_tokens)
+    Vec1D = get_subword_vec1d(args.subword_type)
+    vectorizer = Vec1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file,
+                       mxlen=args.nctx, emit_begin_tok=args.begin_token, emit_end_tok=args.end_token, extra_tokens=args.extra_tokens)
+
+    #vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx, emit_begin_tok=cls, emit_end_tok=end, extra_tokens=args.extra_tokens)
     vocab = vectorizer.vocab.copy()
     # If we are not using chars, then use 'x' for both input and output
     preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, counts=False, known_vocab=vocab, embed_type=args.embed_type, preserve_vocab_indices=True)
@@ -140,10 +155,28 @@ def main():
                          rpr_k=args.rpr_k, rpr_value_on=args.rpr_value_on, d_k=args.d_k, checkpoint_name=checkpoint, activation=args.activation)
     model.to(args.device)
+
 
     index2word = revlut(vocab)
-    print('[Query]', args.query)
-    bpe_out = decode_sentence(model, vectorizer, args.query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)
-    print('[Response]', ' '.join(bpe_out))
+    if args.query:
+        print('[Query]', args.query)
+        bpe_out = decode_sentence(model, vectorizer, args.query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)
+        print('[Response]', ' '.join(bpe_out))
+        return
+
+    from prompt_toolkit import prompt
+
+    from prompt_toolkit.history import FileHistory
+    prompt_name = '->> '
+    history_file = '.history'
+    history = FileHistory(history_file)
+    while True:
+        query = prompt(prompt_name, history=history)
+        query = query.strip()
+        if query == 'quit':
+            break
+        print('[Query]', query)
+        bpe_out = decode_sentence(model, vectorizer, query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)
+        print('[Response]', ' '.join(bpe_out))
 
 
 main()
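
Note on patch 1: with --query now optional, the example runs one-shot when a query is given and otherwise drops into a prompt_toolkit REPL. The sketch below shows what one turn feeds the model; it is illustrative only (the file paths are placeholders, and it assumes baseline's usual Vectorizer1D run() contract of padded token ids plus a valid length), not code from this series.

    # Illustrative sketch, not part of the patch
    from baseline.vectorizers import BPEVectorizer1D

    # 'codes.bpe' and 'vocab.bpe' are placeholder paths, not files from this repo
    vectorizer = BPEVectorizer1D(model_file='codes.bpe', vocab_file='vocab.bpe', mxlen=128)
    # ids is padded out to mxlen; valid_length counts the real subword tokens
    ids, valid_length = vectorizer.run('hello , how are you ?'.split(), vectorizer.vocab)
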
From 8c43ea3d5f9a205694111595118a109bbb29ff44 Mon Sep 17 00:00:00 2001
From: dpressel
Date: Sun, 6 Mar 2022 21:52:42 -0500
Subject: [PATCH 2/4] more REPL tricks

---
 mead/api_examples/generate_mlm.py | 31 +++++++++++++++++++++----------
 1 file changed, 21 insertions(+), 10 deletions(-)

diff --git a/mead/api_examples/generate_mlm.py b/mead/api_examples/generate_mlm.py
index 961a3aacb..c8f1ef89f 100644
--- a/mead/api_examples/generate_mlm.py
+++ b/mead/api_examples/generate_mlm.py
@@ -52,7 +52,8 @@ def decode_sentence(model, vectorizer, query, word2index, index2word, device, sa
     return words
 
 
-def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name, activation):
+def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name,
+                 activation, layer_norm_eps, layer_norms_after, embeddings_reduction):
 
     rpr_k = listify(rpr_k)
     if len(rpr_k) == 0 or rpr_k[0] < 1:
@@ -63,6 +64,7 @@ def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_va
     logger.info("Creating tied encoder decoder model")
     model = TransformerMaskedLanguageModel.create({'x': embeddings},
                                                   hsz=d_model,
+                                                  embeddings_reduction=embeddings_reduction,
                                                   d_ff=d_ff,
                                                   tie_weights=True,
                                                   dropout=0,
@@ -72,6 +74,8 @@ def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_va
                                                   rpr_k=rpr_k,
                                                   rpr_value_on=rpr_value_on,
                                                   d_k=d_k,
+                                                  layer_norm_eps=layer_norm_eps,
+                                                  layer_norms_after=layer_norms_after,
                                                   activation=activation,
                                                   src_keys=['x'], tgt_key='x')
     if checkpoint_name.endswith('npz'):
@@ -114,9 +118,12 @@ def main():
     parser.add_argument("--subword_model_file", type=str, required=False)
     parser.add_argument("--subword_vocab_file", type=str, required=False)
     parser.add_argument("--subword_type", type=str, choices=["bpe", "wordpiece", "sentencepiece"], default="bpe")
-    parser.add_argument("--use_cls", type=str2bool, default=False)
     parser.add_argument("--rpr_value_on", type=str2bool, default=False)
     parser.add_argument('--end_token', default='')
+    parser.add_argument('--begin_token', default='[CLS]')
+    parser.add_argument('--embeddings_reduction', default='sum')
+    parser.add_argument("--layer_norms_after", type=str2bool, default=False, help="Layer norms after (set True for BERT)")
+    parser.add_argument('--layer_norm_eps', default=1e-6, type=float)
     parser.add_argument("--activation", type=str, default='gelu')
     parser.add_argument('--rpr_k',
                         help='Relative attention positional sizes pass 0 if you dont want relative attention', type=int, default=[8], nargs='+')
@@ -126,15 +133,8 @@ def main():
     parser.add_argument("--device", type=str,
                         default="cuda" if torch.cuda.is_available() else "cpu",
                         help="Device (cuda or cpu)")
-    parser.add_argument("--begin_token", default=["[CLS]"])
     args = parser.parse_args()
 
-
-    if torch.cuda.device_count() == 1:
-        torch.cuda.set_device(0)
-        args.device = torch.device("cuda", 0)
-
-
     if os.path.isdir(args.checkpoint):
         checkpoint, _ = find_latest_checkpoint(args.checkpoint)
         logger.warning("Found latest checkpoint %s", checkpoint)
@@ -152,7 +152,10 @@ def main():
     embeddings = preproc_data['embeddings']
     vocab = preproc_data['vocab']
     model = create_model(embeddings, d_model=args.d_model, d_ff=args.d_ff, num_heads=args.num_heads, num_layers=args.num_layers,
-                         rpr_k=args.rpr_k, rpr_value_on=args.rpr_value_on, d_k=args.d_k, checkpoint_name=checkpoint, activation=args.activation)
+                         rpr_k=args.rpr_k, rpr_value_on=args.rpr_value_on,
+                         d_k=args.d_k, checkpoint_name=checkpoint, activation=args.activation,
+                         layer_norm_eps=args.layer_norm_eps, layer_norms_after=args.layer_norms_after,
+                         embeddings_reduction=args.embeddings_reduction)
 
     model.to(args.device)
@@ -175,6 +178,14 @@ def main():
         query = query.strip()
         if query == 'quit':
             break
+        if query == ':sample':
+            args.sample = True
+            print("Sampling mode on")
+            continue
+        if query == ':max':
+            args.sample = False
+            print("Sampling mode off")
+            continue
         print('[Query]', query)
         bpe_out = decode_sentence(model, vectorizer, query.split(), vocab, index2word, args.device, sample=args.sample, y_only=args.y_only)
         print('[Response]', ' '.join(bpe_out))
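
Note on patch 2: the new layer_norms_after, layer_norm_eps and embeddings_reduction flags exist so that post-layer-norm (BERT-style) checkpoints can be described on the command line. As a hedged illustration, such a checkpoint would plausibly combine them as below; the values are the customary BERT ones, not taken from this diff, 'embeddings' is assumed to be built as in main(), and 'bert-base.npz' is a placeholder checkpoint name.

    # Illustrative sketch, not part of the patch
    bert_style = dict(
        embeddings_reduction='sum-layer-norm',  # LayerNorm applied over the summed embeddings
        layer_norms_after=True,                 # post-LN blocks, per the --layer_norms_after help text
        layer_norm_eps=1e-12,                   # BERT's customary epsilon; the flag default is 1e-6
    )
    model = create_model(embeddings, d_model=768, d_ff=3072, num_heads=12, num_layers=12,
                         rpr_k=[0], rpr_value_on=False, d_k=None, checkpoint_name='bert-base.npz',
                         activation='gelu', **bert_style)

Passing rpr_k=[0] disables relative attention via create_model's own rpr_k[0] < 1 check.
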
From fc1a10c6a2c8448a4e0f359015ff99a65f569bcf Mon Sep 17 00:00:00 2001
From: dpressel
Date: Mon, 7 Mar 2022 10:51:50 -0500
Subject: [PATCH 3/4] add support for persisted output bias

---
 mead/api_examples/convert_hf2npz.py | 2 ++
 mead/api_examples/generate_mlm.py   | 9 +++++----
 2 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/mead/api_examples/convert_hf2npz.py b/mead/api_examples/convert_hf2npz.py
index 32f314e8b..b7bc2e24e 100644
--- a/mead/api_examples/convert_hf2npz.py
+++ b/mead/api_examples/convert_hf2npz.py
@@ -106,6 +106,8 @@ def create_transformer_lm(config_url: str) -> Tuple[TransformerMaskedLanguageMod
                                                   embeddings_dropout=pdrop,
                                                   dropout=pdrop,
                                                   activation=activation,
+                                                  output_bias=True,
+                                                  layer_norm_eps=layer_norm_eps,
                                                   layer_norms_after=True,
                                                   embeddings_reduction='sum-layer-norm')
     return model, num_layers
diff --git a/mead/api_examples/generate_mlm.py b/mead/api_examples/generate_mlm.py
index c8f1ef89f..6f9f23d9c 100644
--- a/mead/api_examples/generate_mlm.py
+++ b/mead/api_examples/generate_mlm.py
@@ -53,7 +53,7 @@ def decode_sentence(model, vectorizer, query, word2index, index2word, device, sa
 
 
 def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name,
-                 activation, layer_norm_eps, layer_norms_after, embeddings_reduction):
+                 activation, layer_norm_eps, layer_norms_after, embeddings_reduction, output_bias):
 
     rpr_k = listify(rpr_k)
     if len(rpr_k) == 0 or rpr_k[0] < 1:
@@ -77,6 +77,7 @@ def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_va
                                                   layer_norm_eps=layer_norm_eps,
                                                   layer_norms_after=layer_norms_after,
                                                   activation=activation,
+                                                  output_bias=output_bias,
                                                   src_keys=['x'], tgt_key='x')
     if checkpoint_name.endswith('npz'):
         load_tlm_npz(model, checkpoint_name)
@@ -114,13 +115,14 @@ def main():
     parser.add_argument("--num_layers", type=int, default=8, help="Number of layers")
     parser.add_argument("--nctx", type=int, default=128, help="Max context length (for both encoder and decoder)")
     parser.add_argument("--embed_type", type=str, default='default',
-                        help="register label of the embeddings, so far support positional or learned-positional")
+                        help="register label of the embeddings")
     parser.add_argument("--subword_model_file", type=str, required=False)
     parser.add_argument("--subword_vocab_file", type=str, required=False)
     parser.add_argument("--subword_type", type=str, choices=["bpe", "wordpiece", "sentencepiece"], default="bpe")
     parser.add_argument("--rpr_value_on", type=str2bool, default=False)
     parser.add_argument('--end_token', default='')
     parser.add_argument('--begin_token', default='[CLS]')
+    parser.add_argument('--output_bias', default=False, type=str2bool)
     parser.add_argument('--embeddings_reduction', default='sum')
     parser.add_argument("--layer_norms_after", type=str2bool, default=False, help="Layer norms after (set True for BERT)")
     parser.add_argument('--layer_norm_eps', default=1e-6, type=float)
@@ -145,7 +147,6 @@ def main():
     vectorizer = Vec1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file,
                        mxlen=args.nctx, emit_begin_tok=args.begin_token, emit_end_tok=args.end_token, extra_tokens=args.extra_tokens)
 
-    #vectorizer = BPEVectorizer1D(model_file=args.subword_model_file, vocab_file=args.subword_vocab_file, mxlen=args.nctx, emit_begin_tok=cls, emit_end_tok=end, extra_tokens=args.extra_tokens)
     vocab = vectorizer.vocab.copy()
     # If we are not using chars, then use 'x' for both input and output
     preproc_data = baseline.embeddings.load_embeddings('x', dsz=args.d_model, counts=False, known_vocab=vocab, embed_type=args.embed_type, preserve_vocab_indices=True)
@@ -155,7 +156,7 @@ def main():
                          rpr_k=args.rpr_k, rpr_value_on=args.rpr_value_on,
                          d_k=args.d_k, checkpoint_name=checkpoint, activation=args.activation,
                          layer_norm_eps=args.layer_norm_eps, layer_norms_after=args.layer_norms_after,
-                         embeddings_reduction=args.embeddings_reduction)
+                         embeddings_reduction=args.embeddings_reduction, output_bias=args.output_bias)
 
     model.to(args.device)
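
Note on patch 3: output_bias controls whether the tied-weight MLM head carries a learned per-vocabulary-item bias on top of the shared embedding matrix. In plain PyTorch the head amounts to roughly the following; this is an illustrative sketch with made-up names, not code from this series.

    # Illustrative sketch, not part of the patch
    import torch
    import torch.nn as nn

    class TiedOutputHead(nn.Module):
        """Tied MLM output head: logits = hidden @ E^T + b."""
        def __init__(self, embedding_weight: torch.Tensor):
            super().__init__()
            self.weight = embedding_weight                                    # shared with the input lookup table
            self.bias = nn.Parameter(torch.zeros(embedding_weight.size(0)))  # the piece output_bias=True persists

        def forward(self, hidden: torch.Tensor) -> torch.Tensor:
            # project hidden states back onto the vocabulary, then shift by the bias
            return hidden @ self.weight.t() + self.bias

BERT checkpoints ship exactly such a bias (cls.predictions.bias), which is why the converter now passes output_bias=True and the next patch maps that key.
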
From f479cde6a7a2c3ef7cd943ad281eae62e779493a Mon Sep 17 00:00:00 2001
From: dpressel
Date: Mon, 7 Mar 2022 10:57:36 -0500
Subject: [PATCH 4/4] cleanup keys, add BERT output bias

---
 layers/eight_mile/pytorch/serialize.py | 22 ++++++++++++++++++++--
 1 file changed, 20 insertions(+), 2 deletions(-)

diff --git a/layers/eight_mile/pytorch/serialize.py b/layers/eight_mile/pytorch/serialize.py
index d06237d5f..78891fc36 100644
--- a/layers/eight_mile/pytorch/serialize.py
+++ b/layers/eight_mile/pytorch/serialize.py
@@ -91,6 +91,7 @@
     'bert.embeddings.LayerNorm.gamma': 'embeddings.reduction.ln.weight',
     'bert.embeddings.LayerNorm.bias': 'embeddings.reduction.ln.bias',
     'bert.embeddings.LayerNorm.weight': 'embeddings.reduction.ln.weight',
+    'bert.cls.predictions.bias': 'output_layer.bias'
 }
 
 ROBERTA_HF_LAYER_MAP = {
@@ -162,12 +163,16 @@ def convert_transformers_keys(num_layers: int, d: Dict, nested_layer_map: Dict =
         try:
             m[v.format(i)] = d[k.format(i)]
         except:
-            print(f"Bad key. Skipping {k.format(i)}")
+            # If the checkpoint uses the old gamma/beta LayerNorm names, these keys are expected to be missing; skipping them is not an error
+            if 'LayerNorm.weight' not in k and 'LayerNorm.bias' not in k:
+                print(f"Bad key. Skipping {k.format(i)}")
     for k, v in flat_map.items():
         try:
             m[v] = d[k]
         except:
-            print(f"Bad key. Skipping {k}")
+            # If the checkpoint uses the old gamma/beta LayerNorm names, these keys are expected to be missing; skipping them is not an error
+            if 'LayerNorm.weight' not in k and 'LayerNorm.bias' not in k:
+                print(f"Bad key. Skipping {k}")
 
     return m
 
@@ -453,6 +458,8 @@ def to_tlm_array(pytorch_tlm: nn.Module, embeddings_keys: List[str] = None, name
         if hasattr(pytorch_tlm.embeddings.reduction, 'ln'):
             d.update(to_weight_array(pytorch_tlm.embeddings.reduction.ln, name=f"{name}/Embeddings/reduction/ln"))
+
+
 
     return d
 
 
@@ -467,6 +474,12 @@ def save_tlm_npz(pytorch_tlm: nn.Module, npz: str, embeddings_keys: List[str] = 
     :return: None
     """
     d = to_tlm_array(pytorch_tlm, embeddings_keys, name)
+    # This might not be the best way to do this, but it should work
+    # We don't want to put this in to_tlm_array because there are other cases where we need it to be something else
+    if hasattr(pytorch_tlm, 'output_layer') and hasattr(pytorch_tlm.output_layer, 'bias') and pytorch_tlm.output_layer.bias is not None:
+        bias = pytorch_tlm.output_layer.bias.cpu().detach().numpy()
+        d.update({f"{name}/output/bias": bias})
+
     if verbose:
         print(d.keys())
     np.savez(npz, **d)
@@ -774,6 +787,11 @@ def load_tlm_npz(pytorch_tlm: nn.Module, npz: str, embeddings_keys: List[str] = 
     d = np.load(npz)
     from_tlm_array(pytorch_tlm, d, embeddings_keys, name)
+
+    if hasattr(pytorch_tlm, 'output_layer') and hasattr(pytorch_tlm.output_layer, 'bias') and pytorch_tlm.output_layer.bias is not None:
+        device = pytorch_tlm.output_layer.bias.device
+        pytorch_tlm.output_layer.bias = nn.Parameter(torch.from_numpy(d[f"{name}/output/bias"]).to(device=device), requires_grad=True)
+
 
 def load_tlm_output_npz(pytorch_tlm: nn.Module, npz: str, embeddings_keys: List[str] = None, name: str = "TLM"):
     """Restore a TLM-like model (possibly a `nn.Module` for fine-tuning
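
Note on patch 4: the npz key for the persisted bias is f"{name}/output/bias", with name defaulting to "TLM" as in the signatures above, so a round trip can be sanity-checked as below. This is an illustrative sketch: 'ckpt.npz' is a placeholder path and 'model' is assumed to have been built with output_bias=True.

    # Illustrative sketch, not part of the patch
    import numpy as np
    from eight_mile.pytorch.serialize import save_tlm_npz, load_tlm_npz

    save_tlm_npz(model, 'ckpt.npz')                 # now also writes 'TLM/output/bias' when a bias is present
    assert 'TLM/output/bias' in np.load('ckpt.npz').files
    load_tlm_npz(model, 'ckpt.npz')                 # restores the bias onto model.output_layer.bias
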