Add batch inferencing support for GPT2LMHeadModel #7552

Merged: 3 commits merged into huggingface:master on Oct 14, 2020

Conversation

@cccntu (Contributor) commented on Oct 3, 2020

What does this PR do?

This PR adds the correct (absolute) position embeddings when an attention mask is given; the position IDs are computed from the attention mask.
Fixes #3021
Here is an example usage:

from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)

# when generating, we will use the logits of the right-most token to predict the next token
# so the padding should be on the left
tokenizer.padding_side = "left" 
tokenizer.pad_token = tokenizer.eos_token # to avoid an error

sentences = ["Hello, my dog is a little",
             "Hello, my dog is",  # use different length sentences to test batching
             ]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)


output_sequences = model.generate(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    do_sample=False, # disable sampling to test if batching affects output
)

for i in range(len(sentences)):
    print(tokenizer.decode(output_sequences[i]))
    # you can use skip_special_tokens=True in decode() to remove padding token
    # but note that it will also remove other special_tokens

outputs:

Hello, my dog is a little bit of a mess. I'm not sure if he's going
<|endoftext|><|endoftext|>Hello, my dog is a little bit of a mess. I'm not sure if he
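For reference, here is a minimal sketch (not part of the PR itself) of the cumsum trick used to build position IDs from the attention mask; the mask below is illustrative rather than the exact tokenization of the sentences above:

import torch

# illustrative left-padded attention mask (second row has two padding positions)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1, 1, 1]])

position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 3, 4, 5, 6],
#         [1, 1, 0, 1, 2, 3, 4]])
# padded positions get a dummy value of 1; real tokens are numbered from 0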

comment:

  • I think this should be used in examples/text-generation/run_generation.py, but I don't know much about the other models, and it would be odd if only GPT-2 supported batch inferencing.

albert, bert, GPT2, XLM: @LysandreJik
TextGeneration: @TevenLeScao
documentation: @sgugger
@patrickvonplaten

@cccntu (Contributor, Author) commented on Oct 9, 2020

This enables significantly faster generation.
Here is a simple test I ran.

                  generate 20 tokens    generate 100 tokens
  batch size = 1  45.2 s                3 min 42 s
  batch size = 32 2.25 s (20x)          8.36 s (26.5x)
# following above code
data = sentences * 128 # total 256 sentences
model.cuda();
data = [' '.join([x]*10) for x in data] # make the prompt longer to be more realistic
from tqdm.auto import tqdm

def test(batchsize = 1, max_gen_len = 20):
    for i in tqdm(range(0, len(data), batchsize)):
        batch = data[i: i+batchsize]
        inputs = tokenizer(batch, return_tensors="pt", padding=True)

        output_sequences = model.generate(
            input_ids=inputs['input_ids'].to(model.device),
            attention_mask=inputs['attention_mask'].to(model.device),
            do_sample=False, # disable sampling to test if batching affects output
            pad_token_id=tokenizer.eos_token_id,
            max_length=len(inputs['input_ids'][0]) + max_gen_len, # let it generate longer
        )
        outputs = [tokenizer.decode(x) for x in output_sequences]


%time test(1, 20)

%time test(32, 20)

%time test(1, 100)

%time test(32, 100)

@patrickvonplaten (Contributor) commented on Oct 13, 2020

Hey @cccntu - this is a great addition! I very much like your approach here.
I also checked that all GPT2 SLOW tests function correctly and added a test to make sure batch generation works as expected!

With the current implementation, the user would not be able to define their own position_ids for generate, since they are always overwritten in prepare_inputs_for_generation, but I think this is OK because:

  1. Previously, it was impossible for the user to use position_ids anyway, because they would have to be extended by 1 at each generation step - a feature which is not implemented
  2. I don't see any reason why position_ids should be different from the way they are implemented in this PR right now

@LysandreJik - this feature was heavily requested by the community (linked a couple of issues below) and I think this is a great way to handle GPT2 batch generation. What do you think?

@patrickvonplaten (Contributor) commented:

Related issues: #6742, #4746, #4824

@patrickvonplaten (Contributor) commented:

@cccntu - Great work on this PR! If this PR is merged and you want to help the community a tiny bit more, you could give a short description (similar to what you've done above) on how to do batch generation with GPT2 here: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517. Many people have been asking for this so they would be very glad to see a short forum post about it.

Thanks a lot again!

Comment on lines +705 to +714
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
    # create position_ids on the fly for batch generation
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
    if past:
        position_ids = position_ids[:, -1].unsqueeze(-1)
else:
    position_ids = None
@cccntu (Contributor, Author) replied:
@patrickvonplaten
Now that you add position_ids = kwargs.get("position_ids", None), I think we can get rid of the else: position_ids = None branch.

Also, inspired by the related PR #7355, I think we should move all the if past handling together, just above the return.

Should I add another commit?
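For illustration, a rough sketch of the cleanup proposed above (an assumption about the shape of the change, not the code that was merged):

position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
    # create position_ids on the fly for batch generation
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)
# no else: position_ids = None branch needed -- kwargs.get() already defaults to None

if past:
    # grouped with the rest of the if-past handling, just above the return
    position_ids = position_ids[:, -1].unsqueeze(-1) if position_ids is not None else None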

@LysandreJik (Member) replied:
No strong opinions on this, will let @patrickvonplaten decide to merge with or without this

@patrickvonplaten (Contributor) replied:
@cccntu - yeah, I thought about this as well. The problem with this, with PR #7355, and with passing position_ids in general is that we would have to incrementally add new tokens to position_ids in generate(), which would be pretty hacky since not all models support position_ids, so I'd rather not do this before a bigger refactor of generate (see #6949; I will continue on the bigger refactor soon).

We can always change that later without breaking backwards compatibility.

@LysandreJik (Member) left a review:
This is great, very simple implementation! Thanks a lot @cccntu.

@patrickvonplaten (Contributor) commented on Oct 14, 2020

Awesome, great work @cccntu! It would be amazing if you could write a little description of how your PR works on the forum: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517 - the community would be very thankful, I think :-)

@patrickvonplaten patrickvonplaten merged commit 121dd43 into huggingface:master Oct 14, 2020
@cccntu (Contributor, Author) commented on Oct 14, 2020

@patrickvonplaten Thanks for the suggestions! I just added some description to the forum post. 😃

link to the post for future reference: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517/2

@LSinev (Contributor) commented on Oct 19, 2020

Can you please add batch inferencing for GPT2DoubleHeadsModel too?

fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
* Add support for gpt2 batch inferencing

* add test

* remove typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
@spate141 commented on Mar 1, 2021

@patrickvonplaten @cccntu

I can see that batch generation is now available. I was wondering if there's already a way to do the same but with different max_length & min_length arguments per encoded text in a batch passed to model.generate(). The goal here is to generate new text for a batch of encoded texts of variable size.

@cccntu (Contributor, Author) commented on Mar 2, 2021

Hi @spate141,

Did you mean passing max_len & min_length as n-element arrays?
It would fail here:

assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."

Actually, the main issue is here:
next_token_logits = outputs.logits[:, -1, :]

We need the right-most logits to not be padding; without modifying generation_utils.py, that means we need left padding, and consequently we need this PR to make sure the position embeddings are correct.

You can also check out the discussion in #3021, or the forum post: https://discuss.huggingface.co/t/batch-generation-with-gpt2/1517/3
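To make the left-padding point concrete, here is a small illustration (token IDs are shown for illustration only):

import torch

pad_id = 50256  # GPT-2 reuses the eos token as the pad token here

right_padded = torch.tensor([[15496, 11, 616, pad_id, pad_id]])  # tokens + padding
left_padded  = torch.tensor([[pad_id, pad_id, 15496, 11, 616]])  # padding + tokens

# generate() predicts the next token from the logits of the last position,
# i.e. next_token_logits = outputs.logits[:, -1, :]
print(right_padded[0, -1].item())  # 50256 -> last position is padding, useless for prediction
print(left_padded[0, -1].item())   # 616   -> last position is a real token, as required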

@spate141 commented on Mar 2, 2021

Did you mean passing max_len & min_length as n-element arrays?

  • Yes, exactly! Instead of single int values for all texts in a batch, an array of values, one per text in the batch.

I saw the code and I can see why it will fail. #3021 seems informative, I'll take a look.

Meanwhile I found this way to get what I mentioned:

  • Let's assume a model accepts inputs of max_len = 64 and we want to generate new text for a piece of text of 300 tokens.
  • Since we know what the max_len is, we have to make sure we split our input text into 6 segments: [64, 60, 58, 50, 56, 12] tokens.
    • This is done in some clever way to ensure that each text segment follows valid grammar rules and doesn't go above the max_len limit.
  • For all 6 text segments we want to generate new text with the following min/max values:
    • min_values: [100, 100, 100, 100, 100, 25]
    • max_values: [120, 120, 120, 120, 120, 50]
  • To do that, I can just pass global min & max values (100 and 120 respectively) to model.generate() along with the tokenized batch of input text segments.
    • input_ids_shape: (6, 64), min_len: 100, max_len: 120
  • My only issue is with the last text segment in the (6, 64) tokenized tensor. Ideally, we want newly generated text of between 25 and 50 tokens; generating 100 new tokens from an input of 12 tokens will be gobbledygook.
  • To handle this, I can just take the generated text that belongs to our last input segment and discard everything above its original min/max limit, i.e. (25, 50) (a rough sketch of this appears after the options below).

OR

  • I can do the same but combine the first 5 text segments and generate text on (5, 64), then generate text for the last one on (1, 64), in two passes

OR

  • I can just generate everything in 6 passes, one per text segment, passing each segment's individual min/max limits
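A rough sketch of the first option, following the earlier setup in this thread (segments is the list of 6 text chunks; per_segment_max is a hypothetical list of per-segment budgets):

inputs = tokenizer(segments, return_tensors="pt", padding=True)
prompt_len = inputs["input_ids"].shape[1]

output_sequences = model.generate(
    input_ids=inputs["input_ids"].to(model.device),
    attention_mask=inputs["attention_mask"].to(model.device),
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    min_length=prompt_len + 100,  # global minimum of 100 new tokens
    max_length=prompt_len + 120,  # global maximum of 120 new tokens
)

# trim the generated part per segment afterwards
per_segment_max = [120, 120, 120, 120, 120, 50]
trimmed = [seq[: prompt_len + n_max] for seq, n_max in zip(output_sequences, per_segment_max)]
texts = [tokenizer.decode(t, skip_special_tokens=True) for t in trimmed]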

@cccntu In your second comment on this pull request, you posted some impressive results on why batch generation is ideal, especially when you have a GPU. I'm just trying to figure out whether doing the same in my case is worth the latency when I have to do some post-processing. I'll post some latency results once I have this setup ready.

@spate141 commented on Mar 2, 2021

Update: @cccntu

I went with my first approach, generating text for all texts in a single batch with global min/max values. In most cases my last text chunk in the batch is smaller, meaning its min/max values are smaller than those of the rest of the text chunks in the same batch; I'm just trimming tokens there. Results are impressive so far. Some numbers in case someone stumbles upon this thread in the future:

Fixed size text batches:

  • This compares passing a list of text chunks as a single batch tensor vs. passing the text chunks individually in a for loop. The max_len and min_len variables are kept the same in both. The Y-axis shows the total time in seconds for the model to finish generating text.
  • All the text chunks are the same size.

[chart: total generation time, batched vs. one-at-a-time, fixed-size text chunks]

Variable size text batches:

  • Same as above, but here I'm using variable-size text chunks.
  • For example, "2 Long, 1 Short" means my input is 2 long texts + 1 short text. This tests what happens when generating text for variable-size text chunks in a single batch.
  • Also note that I'm trimming the generated text for the short text chunks in post-processing, so the time on the Y-axis includes that.

[chart: total generation time, batched vs. one-at-a-time, variable-size text chunks]

Overall, batch text generation seems very useful (🎉), even though one has to add some overhead on top to manage certain use cases.

@callzhang commented:
@cccntu Thanks for your great work! I stumbled upon this thread and would like to know:

  1. Would this batching mechanism work for GPT-Neo?
  2. Would this batching mechanism work for pipeline inference?
    If so, are there any changes or considerations I need to make or be aware of?

@thomas-li-sjtu commented:
Thanks for the code! I wonder if I could now generate sentences in a batch with other models (BertGeneration, for instance)? Looking forward to your reply!

@irasin commented on Jan 19, 2023

@cccntu Thanks for your code. By using the correct position_ids in this case, we can do batch inference with the PyTorch model now.

But when we export the GPT-2 model to ONNX with GPT2OnnxConfig:

from transformers.models.gpt2 import GPT2OnnxConfig  # import path may vary by transformers version

onnx_config = GPT2OnnxConfig(model.config)
# or, using past_key_values mode:
# onnx_config = GPT2OnnxConfig(model.config, use_past=True)

Then the ONNX model inputs contain only input_ids and attention_mask, not position_ids.
So we can't do correct batch inference with the ONNX model now, right?
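One quick way to confirm which inputs the exported graph expects is to inspect it with onnxruntime (a sketch; the file name gpt2.onnx is an assumption):

import onnxruntime as ort

sess = ort.InferenceSession("gpt2.onnx")
print([inp.name for inp in sess.get_inputs()])
# expected to list input_ids and attention_mask, with no position_ids entry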

@williamLyh commented:
Thank you for the code. I wonder whether you have tested if there is a performance drop when using batch generation, especially when the GPT-2 model is fine-tuned on right-padded data?
