
Add batch inferencing support for GPT2LMHeadModel #7552

Merged
3 commits merged on Oct 14, 2020
13 changes: 13 additions & 0 deletions src/transformers/modeling_gpt2.py
@@ -701,10 +701,23 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
Comment on lines +705 to +714
Contributor Author


@patrickvonplaten
Now that you add
position_ids = kwargs.get("position_ids", None)
I think we can get rid of
else: position_ids = None

Also, inspired by the related PR #7355, I think we should move all the if past checks together, just above the return.

Should I add another commit?
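
For readers following along, a minimal sketch of that proposed restructuring (illustrative only; the merged diff keeps the structure shown above, and the behavior would be the same):

# Sketch of the proposal above, not the merged code: reuse kwargs-provided
# position_ids, drop the explicit `else`, and group the `if past` slicing
# just above the return.
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if position_ids is not None:
            position_ids = position_ids[:, -1].unsqueeze(-1)

    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }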

Member


No strong opinions on this; I'll let @patrickvonplaten decide whether to merge with or without this change.

Contributor


@cccntu - yeah, I thought about this as well. The problem with this, with PR #7355, and with passing position_ids in general is that we would have to incrementally add new tokens to position_ids in generate(), which would be pretty hacky since not all models support position_ids. So I'd rather not do this before a bigger refactor of generate, see #6949 (will continue on the bigger refactor soon).

We can always change that later without breaking backwards compatibility.
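
To make the objection concrete, a hypothetical sketch of the per-step bookkeeping generate() would need if callers passed position_ids in directly (this is not part of this PR or of generate()):

import torch

# Hypothetical: state that a caller-supplied position_ids would need inside
# the generation loop. Each newly sampled token would require appending the
# next position per sequence, which not every model supports.
position_ids = torch.tensor([[0, 1, 2, 3],
                             [1, 1, 0, 1]])  # e.g. one left-padded sequence

for _ in range(3):  # pretend three tokens are generated
    next_positions = position_ids[:, -1:] + 1
    position_ids = torch.cat([position_ids, next_positions], dim=-1)

print(position_ids)
# tensor([[0, 1, 2, 3, 4, 5, 6],
#         [1, 1, 0, 1, 2, 3, 4]])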

        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
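As a side note on what the new position_ids computation produces, here is a standalone sketch of the cumsum trick for a left-padded batch (token counts are illustrative):

import torch

# Left-padded batch: 0 marks padding, 1 marks real tokens.
attention_mask = torch.tensor([
    [1, 1, 1, 1, 1, 1, 1],  # sentence with no padding
    [0, 0, 0, 0, 1, 1, 1],  # shorter sentence, left-padded
])

# Same two lines as in prepare_inputs_for_generation above:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)
# tensor([[0, 1, 2, 3, 4, 5, 6],
#         [1, 1, 1, 1, 0, 1, 2]])
# Real tokens start at position 0 regardless of padding; padded slots get
# a dummy value, and the attention mask hides them anyway.
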
45 changes: 45 additions & 0 deletions tests/test_modeling_gpt2.py
@@ -32,6 +32,7 @@
        GPT2DoubleHeadsModel,
        GPT2LMHeadModel,
        GPT2Model,
        GPT2Tokenizer,
    )


@@ -405,6 +406,50 @@ def test_gpt2_gradient_checkpointing(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
        self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

    @slow
    def test_batch_generation(self):
        model = GPT2LMHeadModel.from_pretrained("gpt2")
        model.to(torch_device)
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        tokenizer.padding_side = "left"

        # Define PAD Token = EOS Token = 50256
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

        # use different length sentences to test batching
        sentences = [
            "Hello, my dog is a little",
            "Today, I",
        ]

        inputs = tokenizer(sentences, return_tensors="pt", padding=True)

        torch.manual_seed(0)
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(torch_device),
            attention_mask=inputs["attention_mask"].to(torch_device),
        )

        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
        output_non_padded = model.generate(input_ids=inputs_non_padded)

        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

        expected_output_sentence = [
            "Hello, my dog is a little bit of a mess. I'm not sure if he's going",
            "Today, I'm going to be doing a lot of research on this. I",
        ]
        self.assertListEqual(expected_output_sentence, batch_out_sentence)
        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

    @slow
    def test_model_from_pretrained(self):
        for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
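
For reference, a minimal end-to-end usage sketch of the batch generation this PR enables, following the same setup as the test above (left padding, pad token reused from EOS); exact generated text depends on generation settings:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# GPT-2 has no pad token; reuse EOS and pad on the left so the last
# token of every row is a real token for generation to continue from.
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

sentences = ["Hello, my dog is a little", "Today, I"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)

outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))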