
Commit 121dd43

Add batch inferencing support for GPT2LMHeadModel (#7552)
* Add support for gpt2 batch inferencing
* add test
* remove typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
1 parent 0c64b18 commit 121dd43
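
For context, batch generation with GPT2LMHeadModel after this change relies on left padding, a pad token, and passing the attention mask to `generate`. A minimal usage sketch (not part of the commit; it assumes the public "gpt2" checkpoint and the `transformers`/`torch` APIs of this release, mirroring the test added below):

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    # GPT-2 has no pad token, so reuse the EOS token and pad on the left
    # so the last real token of every prompt sits at the end of the sequence.
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    prompts = ["Hello, my dog is a little", "Today, I"]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)

    # Passing the attention mask lets the model build position_ids that skip pad tokens.
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True))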

2 files changed: +58 −0 lines

src/transformers/modeling_gpt2.py

Lines changed: 13 additions & 0 deletions
@@ -701,10 +701,23 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
+        attention_mask = kwargs.get("attention_mask", None)
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
         return {
             "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
         }
 
     @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
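
The key idea in the hunk above: with left padding, the default positions 0..n-1 would be wrong for the padded rows, so position_ids are rebuilt from the attention mask. A cumulative sum counts only real tokens, and padded slots get a dummy value of 1 (they are masked out anyway). A small standalone sketch of the same arithmetic, using plain torch and made-up mask values, not part of the commit:

    import torch

    # attention mask for two prompts; the second is left-padded by two tokens
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [0, 0, 1, 1]])

    # cumulative sum over real tokens, shifted to start at 0
    position_ids = attention_mask.long().cumsum(-1) - 1
    # padded positions are filled with a dummy value of 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    print(position_ids)
    # tensor([[0, 1, 2, 3],
    #         [1, 1, 0, 1]])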

tests/test_modeling_gpt2.py

Lines changed: 45 additions & 0 deletions
@@ -33,6 +33,7 @@
         GPT2ForSequenceClassification,
         GPT2LMHeadModel,
         GPT2Model,
+        GPT2Tokenizer,
     )
 
 
@@ -425,6 +426,50 @@ def test_gpt2_gradient_checkpointing(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
         self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)
 
+    @slow
+    def test_batch_generation(self):
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+        model.to(torch_device)
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+        tokenizer.padding_side = "left"
+
+        # Define PAD Token = EOS Token = 50256
+        tokenizer.pad_token = tokenizer.eos_token
+        model.config.pad_token_id = model.config.eos_token_id
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+
+        torch.manual_seed(0)
+        outputs = model.generate(
+            input_ids=inputs["input_ids"].to(torch_device),
+            attention_mask=inputs["attention_mask"].to(torch_device),
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a mess. I'm not sure if he's going",
+            "Today, I'm going to be doing a lot of research on this. I",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
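
One detail in the test worth spelling out: `generate` counts pad tokens toward `max_length`, so the shorter prompt in the batch receives fewer new tokens than it would on its own. To make the single-sentence run comparable, the test shrinks its `max_length` by the number of pads. A sketch of that arithmetic; the token counts and the default `max_length` of 20 are assumptions for illustration, not values asserted by the test:

    # prompt lengths in GPT-2 BPE tokens (assumed values)
    len_long = 7    # "Hello, my dog is a little"
    len_short = 3   # "Today, I"

    num_paddings = len_long - len_short   # 4 pad tokens added to the short prompt
    max_length = 20                       # assumed default model.config.max_length

    # in the batch, the short row holds 4 pads + 16 real tokens = 20 positions,
    # so the stand-alone run must stop at 20 - 4 = 16 tokens to match it
    adjusted_max_length = max_length - num_paddings
    print(adjusted_max_length)  # 16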
