Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add UniLM model #2160

Closed
wants to merge 26 commits into from
Closed
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 236 additions & 0 deletions examples/decode_seq2seq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
"""BERT finetuning runner."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import logging
import glob
import argparse
import math
import random
from tqdm import tqdm, trange
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from transformers import (UnilmTokenizer, WhitespaceTokenizer,
UnilmForSeq2SeqDecode, AdamW, WarmupLinearSchedule, UnilmConfig)

import utils_seq2seq

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys())
for conf in (UnilmConfig,)), ())
MODEL_CLASSES = {
'unilm': (UnilmConfig, UnilmForSeq2SeqDecode, UnilmTokenizer)
}

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)


def detokenize(tk_list):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) I prefer token_list as the name, didnt know what tk was.

r_list = []
for tk in tk_list:
if tk.startswith('##') and len(r_list) > 0:
r_list[-1] = r_list[-1] + tk[2:]
else:
r_list.append(tk)
return r_list


def main():
parser = argparse.ArgumentParser()

# Required parameters
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--model_recover_path", default=None, type=str,
help="The file of fine-tuned pretraining model.")
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
help="Pretrained tokenizer name or path if not the same as model_name")
parser.add_argument("--max_seq_length", default=512, type=int,
help="The maximum total input sequence length after WordPiece tokenization. \n"
"Sequences longer than this will be truncated, and sequences shorter \n"
"than this will be padded.")

# decoding parameters
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--input_file", type=str, help="Input file")
parser.add_argument('--subset', type=int, default=0,
help="Decode a subset of the input dataset.")
parser.add_argument("--output_file", type=str, help="output file")
parser.add_argument("--split", type=str, default="",
help="Data split (train/val/test).")
parser.add_argument('--tokenized_input', action='store_true',
help="Whether the input is tokenized.")
parser.add_argument('--seed', type=int, default=123,
help="random seed for initialization")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument('--batch_size', type=int, default=4,
help="Batch size for decoding.")
parser.add_argument('--beam_size', type=int, default=1,
help="Beam size for searching")
parser.add_argument('--length_penalty', type=float, default=0,
help="Length penalty for beam search")
parser.add_argument('--forbid_duplicate_ngrams', action='store_true')
parser.add_argument('--forbid_ignore_word', type=str, default=None,
help="Forbid the word during forbid_duplicate_ngrams")
parser.add_argument("--min_len", default=None, type=int)
parser.add_argument('--need_score_traces', action='store_true')
parser.add_argument('--ngram_size', type=int, default=3)
parser.add_argument('--max_tgt_length', type=int, default=128,
help="maximum length of target sequence")

args = parser.parse_args()

if args.need_score_traces and args.beam_size <= 1:
raise ValueError(
"Score trace is only available for beam search with beam size > 1.")
if args.max_tgt_length >= args.max_seq_length - 2:
raise ValueError("Maximum tgt length exceeds max seq length - 2.")

device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()

random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)

args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(
args.config_name if args.config_name else args.model_name_or_path, max_position_embeddings=args.max_seq_length)
tokenizer = tokenizer_class.from_pretrained(
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case)

bi_uni_pipeline = []
bi_uni_pipeline.append(utils_seq2seq.Preprocess4Seq2seqDecode(list(tokenizer.vocab.keys()), tokenizer.convert_tokens_to_ids,
args.max_seq_length, max_tgt_length=args.max_tgt_length))

# Prepare model
mask_word_id, eos_word_ids, sos_word_id = tokenizer.convert_tokens_to_ids(
["[MASK]", "[SEP]", "[S2S_SOS]"])
forbid_ignore_set = None
if args.forbid_ignore_word:
w_list = []
for w in args.forbid_ignore_word.split('|'):
if w.startswith('[') and w.endswith(']'):
w_list.append(w.upper())
else:
w_list.append(w)
forbid_ignore_set = set(tokenizer.convert_tokens_to_ids(w_list))
print(args.model_recover_path)
for model_recover_path in glob.glob(args.model_recover_path.strip()):
logger.info("***** Recover model: %s *****", model_recover_path)
model_recover = torch.load(model_recover_path)
model = model_class.from_pretrained(args.model_name_or_path, state_dict=model_recover, config=config, mask_word_id=mask_word_id, search_beam_size=args.beam_size, length_penalty=args.length_penalty,
eos_id=eos_word_ids, sos_id=sos_word_id, forbid_duplicate_ngrams=args.forbid_duplicate_ngrams, forbid_ignore_set=forbid_ignore_set, ngram_size=args.ngram_size, min_len=args.min_len)
del model_recover

model.to(device)

if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError(
"Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model = amp.initialize(model, opt_level=args.fp16_opt_level)

if n_gpu > 1:
model = torch.nn.DataParallel(model)

torch.cuda.empty_cache()
model.eval()
next_i = 0
max_src_length = args.max_seq_length - 2 - args.max_tgt_length

with open(args.input_file, encoding="utf-8") as fin:
input_lines = [x.strip() for x in fin.readlines()]
if args.subset > 0:
logger.info("Decoding subset: %d", args.subset)
input_lines = input_lines[:args.subset]
data_tokenizer = WhitespaceTokenizer() if args.tokenized_input else tokenizer
input_lines = [data_tokenizer.tokenize(
x)[:max_src_length] for x in input_lines]
input_lines = sorted(list(enumerate(input_lines)),
key=lambda x: -len(x[1]))
output_lines = [""] * len(input_lines)
score_trace_list = [None] * len(input_lines)
total_batch = math.ceil(len(input_lines) / args.batch_size)

with tqdm(total=total_batch) as pbar:
while next_i < len(input_lines):
_chunk = input_lines[next_i:next_i + args.batch_size]
buf_id = [x[0] for x in _chunk]
buf = [x[1] for x in _chunk]
next_i += args.batch_size
max_a_len = max([len(x) for x in buf])
instances = []
for instance in [(x, max_a_len) for x in buf]:
for proc in bi_uni_pipeline:
instances.append(proc(instance))
with torch.no_grad():
batch = utils_seq2seq.batch_list_to_batch_tensors(
instances)
batch = [
t.to(device) if t is not None else None for t in batch]
input_ids, token_type_ids, position_ids, input_mask = batch
traces = model(input_ids, token_type_ids,
position_ids, input_mask)
if args.beam_size > 1:
traces = {k: v.tolist() for k, v in traces.items()}
output_ids = traces['pred_seq']
else:
output_ids = traces.tolist()
for i in range(len(buf)):
w_ids = output_ids[i]
output_buf = tokenizer.convert_ids_to_tokens(w_ids)
output_tokens = []
for t in output_buf:
if t in ("[SEP]", "[PAD]"):
break
output_tokens.append(t)
output_sequence = ' '.join(detokenize(output_tokens))
output_lines[buf_id[i]] = output_sequence
if args.need_score_traces:
score_trace_list[buf_id[i]] = {
'scores': traces['scores'][i], 'wids': traces['wids'][i], 'ptrs': traces['ptrs'][i]}
pbar.update(1)
if args.output_file:
fn_out = args.output_file
else:
fn_out = model_recover_path+'.'+args.split
with open(fn_out, "w", encoding="utf-8") as fout:
for l in output_lines:
fout.write(l)
fout.write("\n")

if args.need_score_traces:
with open(fn_out + ".trace.pickle", "wb") as fout_trace:
pickle.dump(
{"version": 0.0, "num_samples": len(input_lines)}, fout_trace)
for x in score_trace_list:
pickle.dump(x, fout_trace)


if __name__ == "__main__":
main()
145 changes: 145 additions & 0 deletions examples/gen_seq_from_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import pickle
import math
import argparse
import glob
import unicodedata
from pathlib import Path
from tqdm import tqdm

from transformers import UnilmTokenizer


def read_traces_from_file(file_name):
with open(file_name, "rb") as fin:
meta = pickle.load(fin)
num_samples = meta["num_samples"]
samples = []
for _ in range(num_samples):
samples.append(pickle.load(fin))
return samples


def get_best_sequence(sample, eos_id, pad_id, length_penalty=None, alpha=None, expect=None, min_len=None):
# if not any((length_penalty, alpha, expect, min_len)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

delete or leave in

# raise ValueError(
# "You can only specify length penalty or alpha, but not both.")
scores = sample["scores"]
wids_list = sample["wids"]
ptrs = sample["ptrs"]

last_frame_id = len(scores) - 1
for i, wids in enumerate(wids_list):
if all(wid in (eos_id, pad_id) for wid in wids):
last_frame_id = i
break
while all(wid == pad_id for wid in wids_list[last_frame_id]):
last_frame_id -= 1

max_score = -math.inf
frame_id = -1
pos_in_frame = -1

for fid in range(last_frame_id + 1):
for i, wid in enumerate(wids_list[fid]):
if fid <= last_frame_id and scores[fid][i] >= 0:
# skip paddings
continue
if (wid in (eos_id, pad_id)) or fid == last_frame_id:
s = scores[fid][i]
if length_penalty:
if expect:
s -= length_penalty * math.fabs(fid+1 - expect)
else:
s += length_penalty * (fid + 1)
elif alpha:
s = s / math.pow((5 + fid + 1) / 6.0, alpha)
if s > max_score:
# if (frame_id != -1) and min_len and (fid+1 < min_len):
# continue
max_score = s
frame_id = fid
pos_in_frame = i
if frame_id == -1:
seq = []
else:
seq = [wids_list[frame_id][pos_in_frame]]
for fid in range(frame_id, 0, -1):
pos_in_frame = ptrs[fid][pos_in_frame]
seq.append(wids_list[fid - 1][pos_in_frame])
seq.reverse()
return seq


def detokenize(tk_list):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defined elsewhere, pls use that one.

r_list = []
for tk in tk_list:
if tk.startswith('##') and len(r_list) > 0:
r_list[-1] = r_list[-1] + tk[2:]
else:
r_list.append(tk)
return r_list


def simple_postprocess(tk_list):
# truncate duplicate punctuations
while tk_list and len(tk_list) > 4 and len(tk_list[-1]) == 1 and unicodedata.category(tk_list[-1]).startswith('P') and all(it == tk_list[-1] for it in tk_list[-4:]):
tk_list = tk_list[:-3]
return tk_list


def main(args):
tokenizer = UnilmTokenizer.from_pretrained(
args.bert_model, do_lower_case=args.do_lower_case)

eos_id, pad_id = set(tokenizer.convert_tokens_to_ids(["[SEP]", "[PAD]"]))
for input_file in tqdm(glob.glob(args.input)):
if not Path(input_file+'.trace.pickle').exists():
continue
print(input_file)
samples = read_traces_from_file(input_file+'.trace.pickle')

results = []

for s in samples:
word_ids = get_best_sequence(s, eos_id, pad_id, alpha=args.alpha,
length_penalty=args.length_penalty, expect=args.expect, min_len=args.min_len)
tokens = tokenizer.convert_ids_to_tokens(word_ids)
buf = []
for t in tokens:
if t in ("[SEP]", "[PAD]"):
break
else:
buf.append(t)
results.append(" ".join(simple_postprocess(detokenize(buf))))

fn_out = input_file+'.'
if args.length_penalty:
fn_out += 'lenp'+str(args.length_penalty)
if args.expect:
fn_out += 'exp'+str(args.expect)
if args.alpha:
fn_out += 'alp'+str(args.alpha)
if args.min_len:
fn_out += 'minl'+str(args.min_len)
with open(fn_out, "w", encoding="utf-8") as fout:
for line in results:
fout.write(line)
fout.write("\n")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--input", type=str, help="Input file.")
parser.add_argument("--bert_model", default=None, type=str, required=True,
help="Bert pre-trained model selected in the list: bert-base-uncased, "
"bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
parser.add_argument("--alpha", default=None, type=float)
parser.add_argument("--length_penalty", default=None, type=float)
parser.add_argument("--expect", default=None, type=float,
help="Expectation of target length.")
parser.add_argument("--min_len", default=None, type=int)
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
args = parser.parse_args()

main(args)
Loading