seq2seq/run_eval.py can take decoder_start_token_id (#5949)
sshleifer committed Jul 21, 2020
1 parent 5b193b3 commit 9dab39f
Showing 3 changed files with 35 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/seq2seq/finetune.py
@@ -327,6 +327,7 @@ def __init__(self, hparams, **kwargs):
         self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
         if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
             self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
+            self.model.config.decoder_start_token_id = self.decoder_start_token_id
         if isinstance(self.tokenizer, MBartTokenizer):
             self.dataset_class = MBartDataset

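For context: MBart starts decoding from the target-language code token, so the added line copies the id looked up via lang_code_to_id into model.config.decoder_start_token_id, where generate() will find it. A minimal sketch of that lookup (checkpoint name and language code are illustrative assumptions, not taken from this commit):

from transformers import MBartTokenizer

# Assumed example checkpoint and target language, for illustration only.
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
tgt_lang = "ro_RO"

# This is the id that finetune.py now writes into model.config.decoder_start_token_id.
decoder_start_token_id = tokenizer.lang_code_to_id[tgt_lang]
print(decoder_start_token_id)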
18 changes: 17 additions & 1 deletion examples/seq2seq/run_eval.py
@@ -30,13 +30,16 @@ def generate_summaries_or_translations(
     device: str = DEFAULT_DEVICE,
     fp16=False,
     task="summarization",
+    decoder_start_token_id=None,
     **gen_kwargs,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")
     model_name = str(model_name)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
+    if decoder_start_token_id is None:
+        decoder_start_token_id = gen_kwargs.pop("decoder_start_token_id", None)

     tokenizer = AutoTokenizer.from_pretrained(model_name)

@@ -48,7 +51,12 @@
             batch = [model.config.prefix + text for text in batch]
         batch = tokenizer(batch, return_tensors="pt", truncation=True, padding="max_length").to(device)
         input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)
-        summaries = model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_kwargs)
+        summaries = model.generate(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_start_token_id=decoder_start_token_id,
+            **gen_kwargs,
+        )
         dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)
         for hypothesis in dec:
             fout.write(hypothesis + "\n")
@@ -66,6 +74,13 @@ def run_generate():
     parser.add_argument("--device", type=str, required=False, default=DEFAULT_DEVICE, help="cuda, cuda:1, cpu etc.")
     parser.add_argument("--task", type=str, default="summarization", help="typically translation or summarization")
     parser.add_argument("--bs", type=int, default=8, required=False, help="batch size")
+    parser.add_argument(
+        "--decoder_start_token_id",
+        type=int,
+        default=None,
+        required=False,
+        help="decoder_start_token_id (otherwise will look at config)",
+    )
     parser.add_argument(
         "--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
     )
@@ -83,6 +98,7 @@ def run_generate():
         device=args.device,
         fp16=args.fp16,
         task=args.task,
+        decoder_start_token_id=args.decoder_start_token_id,
     )
     if args.reference_path is None:
         return
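Functionally, the new --decoder_start_token_id flag is forwarded straight to model.generate, where an explicit value takes precedence over model.config. A rough sketch of the equivalent direct call (checkpoint name, input sentence, and language code are assumptions for illustration, not part of this commit):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "facebook/mbart-large-en-ro"  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

batch = tokenizer(["UN Chief says there is no military solution in Syria"], return_tensors="pt")
# Passing decoder_start_token_id explicitly overrides the value in model.config,
# which is what run_eval.py now exposes on the command line.
summaries = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_start_token_id=tokenizer.lang_code_to_id["ro_RO"],
)
print(tokenizer.batch_decode(summaries, skip_special_tokens=True))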
19 changes: 17 additions & 2 deletions src/transformers/tokenization_utils_base.py
@@ -2255,8 +2255,23 @@ def _pad(

         return encoded_inputs

-    def batch_decode(self, sequences: List[List[int]], **kwargs) -> List[str]:
-        return [self.decode(seq, **kwargs) for seq in sequences]
+    def batch_decode(
+        self, sequences: List[List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
+    ) -> List[str]:
+        """
+        Convert a list of lists of token ids into a list of strings by calling decode.
+
+        Args:
+            sequences: list of lists of tokenized input ids. Can be obtained using the `encode` or `encode_plus` methods.
+            skip_special_tokens: if set to True, will remove special tokens in the decoding.
+            clean_up_tokenization_spaces: if set to True, will clean up the tokenization spaces.
+        """
+        return [
+            self.decode(
+                seq, skip_special_tokens=skip_special_tokens, clean_up_tokenization_spaces=clean_up_tokenization_spaces
+            )
+            for seq in sequences
+        ]

     def decode(
         self, token_ids: List[int], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True
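The batch_decode change only promotes the two decoding options from **kwargs to explicit keyword arguments; the behaviour is still one decode call per sequence. A small usage sketch (the checkpoint name is an assumption):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed example checkpoint
encoded = tokenizer(["translate English to German: Hello world", "A second example"])

# Decodes each sequence in turn, with the options now spelled out explicitly.
texts = tokenizer.batch_decode(
    encoded["input_ids"], skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(texts)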
