Adding batch_size support for (almost) all pipelines #13724

Merged
merged 31 commits into huggingface:master from Narsil:pipeline_batch_size_support on Oct 29, 2021

Conversation

Contributor

@Narsil Narsil commented Sep 24, 2021

What does this PR do?

When running a pipeline on a dataset with a small model (relative to the GPU), it can be beneficial to batch
the forward pass for performance.

This PR addresses this by adding a batch_size argument.

This PR contains

  • Some facilities for batching and unbatching, handled generically rather than within each individual pipeline
  • Automated testing of this functionality for ALL small models + pipelines
  • Disabled for question-answering and zero-shot-classification. They are trickier because they already use batching internally (over candidate labels and question features). The full solution would involve moving the iterator to the real N [hypothesis, template] pairs and batching there, with another iterator on top that recreates the current zero-shot/question-answering results. Should we add that capability, at least for these 2 pipelines we would have a much better idea of alignment.
  • Ran all slow (pipelines) tests without issue
  • Refactor the batch/unbatch logic for better code quality
  • More docs: caveats about this argument, use cases, benchmarks, and so on.
  • Need to think about TF, which currently has no support for either streaming or batching

The good example (https://gist.github.com/Narsil/4e1c36d7cf8477e5c1d580585860810e):

This code was executed on a GTX 970 (and on a Titan RTX, with similar conclusions); the model is distilbert-base-uncased-finetuned-sst-2-english (250MB bin file).

The old pipelines' GPU iteration method is excluded because it's an order of magnitude slower in all cases.
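A minimal sketch of the benchmark loop that produces the numbers below (the exact script is in the gist above; the identical-sentence generator here is just an illustrative stand-in for the real dataset):

import tqdm
from transformers import pipeline

pipe = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0,  # run on the GPU
)

def data(n=5000):
    # Perfectly aligned inputs: every sentence tokenizes to the same length.
    for _ in range(n):
        yield "This is a test sentence."

for batch_size in [1, 8, 64, 256]:
    print("-" * 30)
    print(f"Streaming batch_size={batch_size}")
    # Passing an iterable makes the pipeline stream results one by one;
    # batch_size only controls how many inputs go through the model at once.
    for _ in tqdm.tqdm(pipe(data(), batch_size=batch_size), total=5000):
        pass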

------------------------------
Streaming no batching
100%|██████████████████████████████████████████████████████████████████████| 5000/5000 [00:26<00:00, 187.52it/s]
------------------------------
Streaming batch_size=8
100%|█████████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1205.95it/s]
------------------------------
Streaming batch_size=64
100%|█████████████████████████████████████████████████████████████████████| 5000/5000 [00:02<00:00, 2478.24it/s]
------------------------------
Streaming batch_size=256
100%|█████████████████████████████████████████████████████████████████████| 5000/5000 [00:01<00:00, 2554.43it/s]
(diminishing returns)

This seems promising!

However, this setup has:

  • Perfect alignment (all inputs are exactly the same length)
  • Small model (lots of GPU RAM left for inputs and intermediary results)

Let's look at another example, which might (or might not) be a bit more realistic:
Using varying-size inputs (https://gist.github.com/Narsil/de88b2d7c242c29772a61af56a5c8270)

------------------------------
Streaming no batching
100%|█████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:30<00:00, 32.51it/s]
------------------------------
Streaming batch_size=8
100%|█████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:29<00:00, 33.62it/s]
------------------------------
Streaming batch_size=64
100%|█████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:29<00:00, 34.29it/s]
------------------------------
Streaming batch_size=256
  0%|                                                                                                                                          | 0/1000 [00:01<?, ?it/s]
Traceback (most recent call last):
  File "/home/nicolas/src/transformers/test.py", line 38, in <module>
    for out in tqdm.tqdm(pipe(dataset, batch_size=256), total=len(dataset)):
  File "/home/nicolas/src/transformers/.venv/lib/python3.9/site-packages/tqdm/std.py", line 1133, in __iter__
    for obj in iterable:
....
    hidden_states = self.intermediate_act_fn(hidden_states)
  File "/home/nicolas/src/transformers/.venv/lib/python3.9/site-packages/torch/nn/functional.py", line 1555, in gelu
    return torch._C._nn.gelu(input)
RuntimeError: CUDA out of memory. Tried to allocate 472.00 MiB (GPU 0; 3.95 GiB total capacity; 2.13 GiB already allocated; 266.75 MiB free; 2.49 GiB reserved in total by PyTorch)

Here we can see that no speedup was achieved, and we actually crashed for the large batch size.
This is entirely due to the inputs not being aligned.

The problem can be made even worse when you have large batch sizes and RARE very long sentences (https://gist.github.com/Narsil/357519fd385d864bfec3caf5aa8df575).

------------------------------
Streaming no batching
100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 183.69it/s]
------------------------------
Streaming batch_size=8
100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 265.74it/s]
------------------------------
Streaming batch_size=64
100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:26<00:00, 37.80it/s]
------------------------------
Streaming batch_size=256
  0%|                                                                                 | 0/1000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/nicolas/src/transformers/test.py", line 42, in <module>
    for out in tqdm.tqdm(pipe(dataset, batch_size=256), total=len(dataset)):
....
    q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
RuntimeError: CUDA out of memory. Tried to allocate 376.00 MiB (GPU 0; 3.95 GiB total capacity; 1.72 GiB already allocated; 354.88 MiB free; 2.46 GiB reserved in total by PyTorch)

Here we are actually 5x SLOWER with batch_size=64 than with the non-batched version. That is because the rare long sentence is so long that it forces the whole batch to be padded to its sequence length, using much more memory and processing power (the padding tokens ARE processed by the GPU, they just don't influence the end result).
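To see the padding effect directly, one can inspect the padded shapes (a rough sketch; the sentences and exact shapes are illustrative, not taken from the gist):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")

short = ["This is a short sentence."] * 63
outlier = [" ".join(["very"] * 400) + " long sentence"]

# Without the outlier, the whole padded batch stays small.
print(tokenizer(short, padding=True, return_tensors="pt")["input_ids"].shape)
# -> roughly torch.Size([63, 8])

# One very long sentence forces every row to be padded to its length,
# so the model processes ~50x more tokens for the same number of useful inputs.
print(tokenizer(short + outlier, padding=True, return_tensors="pt")["input_ids"].shape)
# -> roughly torch.Size([64, 400])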

For users, a rule of thumb is:

  • Measure performance on your load, with your hardware. Measure, measure, and keep measuring. Real numbers are the only way to go.
  • If you are latency constrained (live product doing inference), don't batch.
  • If you are using CPU, don't batch.
  • If you care about throughput (you want to run your model on a bunch of static data) on GPU, then:
    • If you have no clue about the sequence_length ("natural" data), don't batch by default; measure, try adding it tentatively, and add OOM checks to recover when it fails (and it will fail at some point if you don't control the sequence_length).
    • If your sequence_length is super regular, then batching is more likely to be VERY interesting; measure and push it until you get OOMs.
    • The larger the GPU, the more likely batching is to be interesting.
  • As soon as you enable batching, make sure you can handle OOMs nicely (see the sketch after this list).
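To illustrate the last point, a rough sketch of one way to recover from OOMs; this is not part of the PR, and the fallback policy (retrying unbatched) is just an example:

import torch
from transformers import pipeline

pipe = pipeline(
    "text-classification",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    device=0,
)

def run_batched(texts, batch_size):
    # Try the batched run; on CUDA OOM, free the cache and retry unbatched.
    try:
        return list(pipe(texts, batch_size=batch_size))
    except RuntimeError as e:
        if "out of memory" not in str(e):
            raise
        torch.cuda.empty_cache()
        # A real policy might halve batch_size and retry instead of going straight to 1.
        return list(pipe(texts, batch_size=1))

results = run_batched(["Some example text"] * 1000, batch_size=64)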

There are no good (general) solutions for this problem, and your mileage may vary depending on your use case, which is why, for now:

  • batch_size=1 by default (both for speed and for OOM issues; we can't guess the correct parameters, and at least with batch_size=1 we have the smallest possible chance of going OOM).
  • batch_size=1 is roughly comparable in speed to batched runs on irregularly sized data (which is an important use case, e.g. live products where latency also matters).
  • Other batch_sizes are opt-in, because they might be valuable for users (for instance when computing a metric on a dataset with very regular input lengths), but it is then the user's responsibility to check for OOMs and slowness.
  • batch_size > 1 won't work for a tokenizer/feature_extractor that doesn't have a padding mechanism (if one is required); see the sketch after this list for how to set a pad token when one is missing.
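For the last point, when a tokenizer has no pad token (gpt2, for example), the usual workaround, also discussed in the review comments below, is to reuse the EOS token. A rough sketch:

from transformers import pipeline

generator = pipeline("text-generation", model="gpt2", device=0)

# gpt2 ships without a pad token, so batching would otherwise be rejected.
generator.tokenizer.pad_token = generator.tokenizer.eos_token
generator.model.config.pad_token_id = generator.model.config.eos_token_id

outputs = generator(
    ["Hello there", "A much longer prompt than the first one"],
    batch_size=2,
    max_length=30,
)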

It would be ideal if pipelines could start taking that responsibility on their shoulders and batch dynamically for users, but it's a hard problem right now:

  • It's hard to anticipate OOMs, and an OOM might happen late (so batch_size would always have to be somewhat dynamic during the streaming process).
  • It's even harder to evaluate the slowdown due to padding; pipelines would have to count padding tokens and implement some kind of batch exclusion mechanism.
  • The padding issue could be helped quite a bit by RaggedTensors; however, they don't play that nicely with GPU capabilities either (GPUs need data that is as aligned/regular as possible).

Some other links/issues/discussions:

#11251
https://discuss.huggingface.co/t/how-to-change-the-batch-size-in-a-pipeline/8738
https://discuss.huggingface.co/t/how-to-make-pipeline-automatically-scale/7432
#13141
#12195
https://gist.github.com/Narsil/ee5c09875e74fa6f018dc6d014f6c06c

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@LysandreJik @sgugger

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Narsil Narsil changed the title Adding batch_size support for (almost) all pipelines [WIP] Adding batch_size support for (almost) all pipelines Sep 24, 2021
@Narsil Narsil force-pushed the pipeline_batch_size_support branch from 354bb4c to ecedb2e Compare October 11, 2021 08:59
@Narsil Narsil changed the title [WIP] Adding batch_size support for (almost) all pipelines Adding batch_size support for (almost) all pipelines Oct 11, 2021
Member

@LysandreJik LysandreJik left a comment


This is a fantastic PR and write-up! Thanks for doing all of the work.

The code looks okay to me, but there are a lot of small changes across pipelines - would it be possible to add comments where those changes are unintuitive so that we may better understand the need addressed? I added comments where I think those would be helpful.

The test changes are clean. Thanks for adding this layer which should make testing simpler for new pipelines.

Finally, the write-up is great, it would be ideal to add it to the documentation. Can you add it to the pipeline RST document?

Comment on lines 683 to 689
k: element[self._unbatch_index].unsqueeze(0)
if isinstance(element[self._unbatch_index], torch.Tensor)
else np.expand_dims(element[self._unbatch_index], 0)
if isinstance(element[self._unbatch_index], np.ndarray)
else element[self._unbatch_index]
for k, element in self._unbatch_data.items()
if k != "past_key_values"
Member


oof this is a tough one to understand, it would be nice to spread it over different lines.

Contributor Author


Rewrote it, hopefully it's better now, can you confirm?

Comment on lines 70 to 108
raise ValueError("Pipeline without tokenizer or feature_extractor cannot do batching")
if tokenizer is not None:
if tokenizer.pad_token_id is None:
raise ValueError("Pipeline with tokenizer without pad_token cannot do batching")
Member


Nice error raising! It would be nice to show how to assign a padding token in that case.

Contributor Author


Can we add a padding_token just like that? Wouldn't it be erasing an existing (likely used) token?

Not sure what you mean.

Member


We generally show how to do this with the following:

model.config.pad_token_id = model.config.eos_token_id

I think this is particularly important for the pipeline as users don't necessarily understand what/how to change the underlying model's attributes, so printing an example of that in the console would be helpful

Collaborator

@sgugger sgugger left a comment


Thanks a lot for all your work on this!

Comment on lines 700 to 701
self._unbatch_index = None
self._unbatch_data = None
Collaborator


For those other variables, I would prefer unpack_xxx to unbatch personally.

Contributor Author


If I make the switch, I will make it for all variables (so unpack_size too), as I consider these completely linked, so using similar names is important.

I am fine with the name, even though I feel we lose the connection to the batch concept.

Is that what you are implying?

Contributor Author


With respect to the other comment, I updated everything to loader_batch_*, which I think is better.

@Narsil Narsil force-pushed the pipeline_batch_size_support branch 2 times, most recently from 376923a to e2d6a6a Compare October 18, 2021 12:11
@Narsil Narsil mentioned this pull request Oct 25, 2021
output_ids = self.model.generate(**model_inputs, **generate_kwargs)
if self.model.config.is_encoder_decoder:
start_position = 1
else:
start_position = n
return {"output_ids": output_ids[0, start_position:], "conversation": conversation}
return {"output_ids": output_ids[:, start_position:], "conversation": conversation}
Contributor Author

@Narsil Narsil Oct 25, 2021


We are changing the interface between _forward and postprocess to keep the batch in the tensors so batching/unbatching can happen.

@@ -204,26 +204,29 @@ def _forward(self, model_inputs):
offset_mapping = model_inputs.pop("offset_mapping", None)
sentence = model_inputs.pop("sentence")
if self.framework == "tf":
outputs = self.model(model_inputs.data)[0][0]
logits = self.model(model_inputs.data)[0]
Contributor Author


We are changing the interface between _forward and postprocess to keep the batch in the tensors so batching/unbatching can happen.

sentence = model_outputs["sentence"]
input_ids = model_outputs["input_ids"][0]
offset_mapping = model_outputs["offset_mapping"][0] if model_outputs["offset_mapping"] is not None else None
special_tokens_mask = model_outputs["special_tokens_mask"][0].numpy()

scores = np.exp(outputs) / np.exp(outputs).sum(-1, keepdims=True)
maxes = np.max(logits, axis=-1, keepdims=True)
Contributor Author


logits trick
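For reference, this refers to the standard max-subtraction trick for a numerically stable softmax; a minimal sketch:

import numpy as np

def stable_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtracting the row-wise max before exponentiating avoids overflow on large
    # logits and does not change the result (softmax is shift-invariant).
    maxes = np.max(logits, axis=-1, keepdims=True)
    shifted_exp = np.exp(logits - maxes)
    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)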

Collaborator


Thanks for adding that

filename = dataset[0]["file"]
output = audio_classifier(filename)
audio = dataset[0]["audio"]["array"]
output = audio_classifier(audio)
Contributor Author


We're not relying on a filename anymore since the tests don't run ffmpeg anymore.

Member

@LysandreJik LysandreJik left a comment


This looks good to me, thank you @Narsil. If it works for you, I would like for this PR to be merged after the v4.12.0 release (tomorrow Thursday) so that it gets a bit of testing on master before being set in stone.

Narsil and others added 14 commits October 29, 2021 11:01
@Narsil Narsil force-pushed the pipeline_batch_size_support branch from 09a1db8 to 5cd831c Compare October 29, 2021 09:01
Contributor Author

Narsil commented Oct 29, 2021

Release done, merging.

@Narsil Narsil merged commit be23636 into huggingface:master Oct 29, 2021
@Narsil Narsil deleted the pipeline_batch_size_support branch October 29, 2021 09:34
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
…3724)

* Tentative enabling of `batch_size` for pipelines.

* Add systematic test for pipeline batching.

* Enabling batch_size on almost all pipelines

- Not `zero-shot` (it's already passing stuff as batched so trickier)
- Not `QA` (preprocess uses squad features, we need to switch to real
tensors at this boundary.

* Adding `min_length_for_response` for conversational.

* Making CTC, speech mappings avaiable regardless of framework.

* Attempt at fixing automatic tests (ffmpeg not enabled for fast tests)

* Removing ffmpeg dependency in tests.

* Small fixes.

* Slight cleanup.

* Adding docs

and adressing comments.

* Quality.

* Update docs/source/main_classes/pipelines.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/question_answering.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/zero_shot_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Improving docs.

* Update docs/source/main_classes/pipelines.rst

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

* N -> oberved_batch_size

softmax trick.

* Follow `padding_side`.

* Supporting image pipeline batching (and padding).

* Rename `unbatch` -> `loader_batch`.

* unbatch_size forgot.

* Custom padding for offset mappings.

* Attempt to remove librosa.

* Adding require_audio.

* torchaudio.

* Back to using datasets librosa.

* Adding help to set a pad_token on the tokenizer.

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Quality.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>