-
Notifications
You must be signed in to change notification settings - Fork 27k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add UniLM model #2160
Closed
Closed
[WIP] Add UniLM model #2160
Changes from 5 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
2bd3702
merge unilm code
donglixp 1aa685b
amp save&load
donglixp 6f519d2
Merge branch 'master' into master
addf400 75d0408
Update modeling_unilm.py
addf400 28bb1ae
Delete transformers.code-workspace
addf400 ad1efc3
update get_linear_schedule_with_warmup
addf400 bbf553d
update get_linear_schedule_with_warmup
addf400 0992152
update checkpoint url & base model
addf400 da897dd
tokenizer for base model
addf400 4803777
tokenizer for base model
addf400 3a683df
Update MIT
addf400 76bfe9a
Add unilm into readme
addf400 8e2ac12
update
addf400 d946faa
Merge branch 'master' into master
addf400 49c016c
Update README.md
addf400 2227907
test for modeling & tokenizer
addf400 0430325
Update licence
addf400 58663ba
Upload model checkpoint
addf400 a97ea6f
upload model config
addf400 cce3218
Update vocab
addf400 f45ad65
Merge branch 'master' into master
addf400 f34a338
Update tokenization_auto.py
addf400 3f891dd
Update config_auto.py
addf400 03125cf
change name
addf400 8ab9bc3
fx decode
addf400 bbacc86
fx scheduler
addf400 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,236 @@ | ||
"""BERT finetuning runner.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
import logging | ||
import glob | ||
import argparse | ||
import math | ||
import random | ||
from tqdm import tqdm, trange | ||
import pickle | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import DataLoader, RandomSampler | ||
from torch.utils.data.distributed import DistributedSampler | ||
|
||
from transformers import (UnilmTokenizer, WhitespaceTokenizer, | ||
UnilmForSeq2SeqDecode, AdamW, WarmupLinearSchedule, UnilmConfig) | ||
|
||
import utils_seq2seq | ||
|
||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) | ||
for conf in (UnilmConfig,)), ()) | ||
MODEL_CLASSES = { | ||
'unilm': (UnilmConfig, UnilmForSeq2SeqDecode, UnilmTokenizer) | ||
} | ||
|
||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', | ||
datefmt='%m/%d/%Y %H:%M:%S', | ||
level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def detokenize(tk_list): | ||
r_list = [] | ||
for tk in tk_list: | ||
if tk.startswith('##') and len(r_list) > 0: | ||
r_list[-1] = r_list[-1] + tk[2:] | ||
else: | ||
r_list.append(tk) | ||
return r_list | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser() | ||
|
||
# Required parameters | ||
parser.add_argument("--model_type", default=None, type=str, required=True, | ||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) | ||
parser.add_argument("--model_name_or_path", default=None, type=str, required=True, | ||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) | ||
parser.add_argument("--model_recover_path", default=None, type=str, | ||
help="The file of fine-tuned pretraining model.") | ||
parser.add_argument("--config_name", default="", type=str, | ||
help="Pretrained config name or path if not the same as model_name") | ||
parser.add_argument("--tokenizer_name", default="", type=str, | ||
help="Pretrained tokenizer name or path if not the same as model_name") | ||
parser.add_argument("--max_seq_length", default=512, type=int, | ||
help="The maximum total input sequence length after WordPiece tokenization. \n" | ||
"Sequences longer than this will be truncated, and sequences shorter \n" | ||
"than this will be padded.") | ||
|
||
# decoding parameters | ||
parser.add_argument('--fp16', action='store_true', | ||
help="Whether to use 16-bit float precision instead of 32-bit") | ||
parser.add_argument('--fp16_opt_level', type=str, default='O1', | ||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." | ||
"See details at https://nvidia.github.io/apex/amp.html") | ||
parser.add_argument("--input_file", type=str, help="Input file") | ||
parser.add_argument('--subset', type=int, default=0, | ||
help="Decode a subset of the input dataset.") | ||
parser.add_argument("--output_file", type=str, help="output file") | ||
parser.add_argument("--split", type=str, default="", | ||
help="Data split (train/val/test).") | ||
parser.add_argument('--tokenized_input', action='store_true', | ||
help="Whether the input is tokenized.") | ||
parser.add_argument('--seed', type=int, default=123, | ||
help="random seed for initialization") | ||
parser.add_argument("--do_lower_case", action='store_true', | ||
help="Set this flag if you are using an uncased model.") | ||
parser.add_argument('--batch_size', type=int, default=4, | ||
help="Batch size for decoding.") | ||
parser.add_argument('--beam_size', type=int, default=1, | ||
help="Beam size for searching") | ||
parser.add_argument('--length_penalty', type=float, default=0, | ||
help="Length penalty for beam search") | ||
parser.add_argument('--forbid_duplicate_ngrams', action='store_true') | ||
parser.add_argument('--forbid_ignore_word', type=str, default=None, | ||
help="Forbid the word during forbid_duplicate_ngrams") | ||
parser.add_argument("--min_len", default=None, type=int) | ||
parser.add_argument('--need_score_traces', action='store_true') | ||
parser.add_argument('--ngram_size', type=int, default=3) | ||
parser.add_argument('--max_tgt_length', type=int, default=128, | ||
help="maximum length of target sequence") | ||
|
||
args = parser.parse_args() | ||
|
||
if args.need_score_traces and args.beam_size <= 1: | ||
raise ValueError( | ||
"Score trace is only available for beam search with beam size > 1.") | ||
if args.max_tgt_length >= args.max_seq_length - 2: | ||
raise ValueError("Maximum tgt length exceeds max seq length - 2.") | ||
|
||
device = torch.device( | ||
"cuda" if torch.cuda.is_available() else "cpu") | ||
n_gpu = torch.cuda.device_count() | ||
|
||
random.seed(args.seed) | ||
np.random.seed(args.seed) | ||
torch.manual_seed(args.seed) | ||
if n_gpu > 0: | ||
torch.cuda.manual_seed_all(args.seed) | ||
|
||
args.model_type = args.model_type.lower() | ||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] | ||
config = config_class.from_pretrained( | ||
args.config_name if args.config_name else args.model_name_or_path, max_position_embeddings=args.max_seq_length) | ||
tokenizer = tokenizer_class.from_pretrained( | ||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case) | ||
|
||
bi_uni_pipeline = [] | ||
bi_uni_pipeline.append(utils_seq2seq.Preprocess4Seq2seqDecode(list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids, | ||
args.max_seq_length, max_tgt_length=args.max_tgt_length)) | ||
|
||
# Prepare model | ||
mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids( | ||
["[MASK]", "[SEP]", "[S2S_SOS]"]) | ||
forbid_ignore_set = None | ||
if args.forbid_ignore_word: | ||
w_list = [] | ||
for w in args.forbid_ignore_word.split('|'): | ||
if w.startswith('[') and w.endswith(']'): | ||
w_list.append(w.upper()) | ||
else: | ||
w_list.append(w) | ||
forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list)) | ||
print(args.model_recover_path) | ||
for model_recover_path in glob.glob(args.model_recover_path.strip()): | ||
logger.info("***** Recover model: %s *****", model_recover_path) | ||
model_recover = torch.load(model_recover_path) | ||
model = model_class.from_pretrained(args.model_name_or_path, state_dict=model_recover, config=config, mask_word_id=mask_word_id, search_beam_size=args.beam_size, length_penalty=args.length_penalty, | ||
eos_id=eos_word_ids, sos_id=sos_word_id, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len) | ||
del model_recover | ||
|
||
model.to(device) | ||
|
||
if args.fp16: | ||
try: | ||
from apex import amp | ||
except ImportError: | ||
raise ImportError( | ||
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") | ||
model = amp.initialize(model, opt_level=args.fp16_opt_level) | ||
|
||
if n_gpu > 1: | ||
model = torch.nn.DataParallel(model) | ||
|
||
torch.cuda.empty_cache() | ||
model.eval() | ||
next_i = 0 | ||
max_src_length = args.max_seq_length - 2 - args.max_tgt_length | ||
|
||
with open(args.input_file, encoding="utf-8") as fin: | ||
input_lines = [x.strip() for x in fin.readlines()] | ||
if args.subset > 0: | ||
logger.info("Decoding subset: %d", args.subset) | ||
input_lines = input_lines[:args.subset] | ||
data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer | ||
input_lines = [data_tokenizer.tokenize( | ||
x)[:max_src_length] for x in input_lines] | ||
input_lines = sorted(list(enumerate(input_lines)), | ||
key=lambda x: -len(x[1])) | ||
output_lines = [""] * len(input_lines) | ||
score_trace_list = [None] * len(input_lines) | ||
total_batch = math.ceil(len(input_lines) / args.batch_size) | ||
|
||
with tqdm(total=total_batch) as pbar: | ||
while next_i < len(input_lines): | ||
_chunk = input_lines[next_i:next_i + args.batch_size] | ||
buf_id = [x[0] for x in _chunk] | ||
buf = [x[1] for x in _chunk] | ||
next_i += args.batch_size | ||
max_a_len = max([len(x) for x in buf]) | ||
instances = [] | ||
for instance in [(x, max_a_len) for x in buf]: | ||
for proc in bi_uni_pipeline: | ||
instances.append(proc(instance)) | ||
with torch.no_grad(): | ||
batch = utils_seq2seq.batch_list_to_batch_tensors( | ||
instances) | ||
batch = [ | ||
t.to(device) if t is not None else None for t in batch] | ||
input_ids, token_type_ids, position_ids, input_mask = batch | ||
traces = model(input_ids, token_type_ids, | ||
position_ids, input_mask) | ||
if args.beam_size > 1: | ||
traces = {k: v.tolist() for k, v in traces.items()} | ||
output_ids = traces['pred_seq'] | ||
else: | ||
output_ids = traces.tolist() | ||
for i in range(len(buf)): | ||
w_ids = output_ids[i] | ||
output_buf = tokenizer.convert_ids_to_tokens(w_ids) | ||
output_tokens = [] | ||
for t in output_buf: | ||
if t in ("[SEP]", "[PAD]"): | ||
break | ||
output_tokens.append(t) | ||
output_sequence = ' '.join(detokenize(output_tokens)) | ||
output_lines[buf_id[i]] = output_sequence | ||
if args.need_score_traces: | ||
score_trace_list[buf_id[i]] = { | ||
'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]} | ||
pbar.update(1) | ||
if args.output_file: | ||
fn_out = args.output_file | ||
else: | ||
fn_out = model_recover_path+'.'+args.split | ||
with open(fn_out, "w", encoding="utf-8") as fout: | ||
for l in output_lines: | ||
fout.write(l) | ||
fout.write("\n") | ||
|
||
if args.need_score_traces: | ||
with open(fn_out + ".trace.pickle", "wb") as fout_trace: | ||
pickle.dump( | ||
{"version": 0.0, "num_samples": len(input_lines)}, fout_trace) | ||
for x in score_trace_list: | ||
pickle.dump(x, fout_trace) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import pickle | ||
import math | ||
import argparse | ||
import glob | ||
import unicodedata | ||
from pathlib import Path | ||
from tqdm import tqdm | ||
|
||
from transformers import UnilmTokenizer | ||
|
||
|
||
def read_traces_from_file(file_name): | ||
with open(file_name, "rb") as fin: | ||
meta = pickle.load(fin) | ||
num_samples = meta["num_samples"] | ||
samples = [] | ||
for _ in range(num_samples): | ||
samples.append(pickle.load(fin)) | ||
return samples | ||
|
||
|
||
def get_best_sequence(sample, eos_id, pad_id, length_penalty=None, alpha=None, expect=None, min_len=None): | ||
# if not any((length_penalty, alpha, expect, min_len)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete or leave in |
||
# raise ValueError( | ||
# "You can only specify length penalty or alpha, but not both.") | ||
scores = sample["scores"] | ||
wids_list = sample["wids"] | ||
ptrs = sample["ptrs"] | ||
|
||
last_frame_id = len(scores) - 1 | ||
for i, wids in enumerate(wids_list): | ||
if all(wid in (eos_id, pad_id) for wid in wids): | ||
last_frame_id = i | ||
break | ||
while all(wid == pad_id for wid in wids_list[last_frame_id]): | ||
last_frame_id -= 1 | ||
|
||
max_score = -math.inf | ||
frame_id = -1 | ||
pos_in_frame = -1 | ||
|
||
for fid in range(last_frame_id + 1): | ||
for i, wid in enumerate(wids_list[fid]): | ||
if fid <= last_frame_id and scores[fid][i] >= 0: | ||
# skip paddings | ||
continue | ||
if (wid in (eos_id, pad_id)) or fid == last_frame_id: | ||
s = scores[fid][i] | ||
if length_penalty: | ||
if expect: | ||
s -= length_penalty * math.fabs(fid+1 - expect) | ||
else: | ||
s += length_penalty * (fid + 1) | ||
elif alpha: | ||
s = s / math.pow((5 + fid + 1) / 6.0, alpha) | ||
if s > max_score: | ||
# if (frame_id != -1) and min_len and (fid+1 < min_len): | ||
# continue | ||
max_score = s | ||
frame_id = fid | ||
pos_in_frame = i | ||
if frame_id == -1: | ||
seq = [] | ||
else: | ||
seq = [wids_list[frame_id][pos_in_frame]] | ||
for fid in range(frame_id, 0, -1): | ||
pos_in_frame = ptrs[fid][pos_in_frame] | ||
seq.append(wids_list[fid - 1][pos_in_frame]) | ||
seq.reverse() | ||
return seq | ||
|
||
|
||
def detokenize(tk_list): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. defined elsewhere, pls use that one. |
||
r_list = [] | ||
for tk in tk_list: | ||
if tk.startswith('##') and len(r_list) > 0: | ||
r_list[-1] = r_list[-1] + tk[2:] | ||
else: | ||
r_list.append(tk) | ||
return r_list | ||
|
||
|
||
def simple_postprocess(tk_list): | ||
# truncate duplicate punctuations | ||
while tk_list and len(tk_list) > 4 and len(tk_list[-1]) == 1 and unicodedata.category(tk_list[-1]).startswith('P') and all(it == tk_list[-1] for it in tk_list[-4:]): | ||
tk_list = tk_list[:-3] | ||
return tk_list | ||
|
||
|
||
def main(args): | ||
tokenizer = UnilmTokenizer.from_pretrained( | ||
args.bert_model, do_lower_case=args.do_lower_case) | ||
|
||
eos_id, pad_id = set(tokenizer.convert_tokens_to_ids(["[SEP]", "[PAD]"])) | ||
for input_file in tqdm(glob.glob(args.input)): | ||
if not Path(input_file+'.trace.pickle').exists(): | ||
continue | ||
print(input_file) | ||
samples = read_traces_from_file(input_file+'.trace.pickle') | ||
|
||
results = [] | ||
|
||
for s in samples: | ||
word_ids = get_best_sequence(s, eos_id, pad_id, alpha=args.alpha, | ||
length_penalty=args.length_penalty, expect=args.expect, min_len=args.min_len) | ||
tokens = tokenizer.convert_ids_to_tokens(word_ids) | ||
buf = [] | ||
for t in tokens: | ||
if t in ("[SEP]", "[PAD]"): | ||
break | ||
else: | ||
buf.append(t) | ||
results.append(" ".join(simple_postprocess(detokenize(buf)))) | ||
|
||
fn_out = input_file+'.' | ||
if args.length_penalty: | ||
fn_out += 'lenp'+str(args.length_penalty) | ||
if args.expect: | ||
fn_out += 'exp'+str(args.expect) | ||
if args.alpha: | ||
fn_out += 'alp'+str(args.alpha) | ||
if args.min_len: | ||
fn_out += 'minl'+str(args.min_len) | ||
with open(fn_out, "w", encoding="utf-8") as fout: | ||
for line in results: | ||
fout.write(line) | ||
fout.write("\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--input", type=str, help="Input file.") | ||
parser.add_argument("--bert_model", default=None, type=str, required=True, | ||
help="Bert pre-trained model selected in the list: bert-base-uncased, " | ||
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") | ||
parser.add_argument("--alpha", default=None, type=float) | ||
parser.add_argument("--length_penalty", default=None, type=float) | ||
parser.add_argument("--expect", default=None, type=float, | ||
help="Expectation of target length.") | ||
parser.add_argument("--min_len", default=None, type=int) | ||
parser.add_argument("--do_lower_case", action='store_true', | ||
help="Set this flag if you are using an uncased model.") | ||
args = parser.parse_args() | ||
|
||
main(args) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit) I prefer token_list as the name, didnt know what tk was.