Adding batch_size support for (almost) all pipelines (#13724)
* Tentative enabling of `batch_size` for pipelines.

* Add systematic test for pipeline batching.

* Enabling batch_size on almost all pipelines

- Not `zero-shot` (it already passes inputs as batches, so it's trickier)
- Not `QA` (preprocess uses squad features; we need to switch to real
  tensors at this boundary)

* Adding `min_length_for_response` for conversational.

* Making CTC, speech mappings available regardless of framework.

* Attempt at fixing automatic tests (ffmpeg not enabled for fast tests)

* Removing ffmpeg dependency in tests.

* Small fixes.

* Slight cleanup.

* Adding docs and addressing comments.

* Quality.

* Update docs/source/main_classes/pipelines.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/question_answering.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/zero_shot_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Improving docs.

* Update docs/source/main_classes/pipelines.rst

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

* N -> observed_batch_size

softmax trick.

* Follow `padding_side`.

* Supporting image pipeline batching (and padding).

* Rename `unbatch` -> `loader_batch`.

* Rename `unbatch_size` too (forgotten in the previous commit).

* Custom padding for offset mappings.

* Attempt to remove librosa.

* Adding require_audio.

* torchaudio.

* Back to using datasets librosa.

* Adding help to set a pad_token on the tokenizer.

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Quality.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
3 people authored Oct 29, 2021
1 parent 4469010 commit be23636
Showing 27 changed files with 629 additions and 64 deletions.
143 changes: 143 additions & 0 deletions docs/source/main_classes/pipelines.rst
@@ -71,6 +71,11 @@ GPU. If it doesn't don't hesitate to create an issue.

.. code-block::

    import datasets
    from transformers import pipeline
    from transformers.pipelines.base import KeyDataset
    import tqdm

    pipe = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h", device=0)
    dataset = datasets.load_dataset("superb", name="asr", split="test")
@@ -85,6 +90,144 @@ GPU. If it doesn't don't hesitate to create an issue.
.. autofunction:: transformers.pipeline

Pipeline batching
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

All pipelines (except `zero-shot-classification` and `question-answering` currently) can use batching. This will work
whenever the pipeline uses its streaming ability (so when passing lists or :obj:`Dataset`).

.. code-block::

    from transformers import pipeline
    from transformers.pipelines.base import KeyDataset
    import datasets
    import tqdm

    dataset = datasets.load_dataset("imdb", name="plain_text", split="unsupervised")
    pipe = pipeline("text-classification", device=0)
    for out in pipe(KeyDataset(dataset, "text"), batch_size=8, truncation="only_first"):
        print(out)
        # [{'label': 'POSITIVE', 'score': 0.9998743534088135}]
        # Exactly the same output as before, but the contents are passed
        # as batches to the model
.. warning::

    However, this is not automatically a win for performance. It can be either a 10x speedup or a 5x slowdown
    depending on hardware, data, and the actual model being used.

Example where batching is mostly a speedup:


.. code-block::

    from transformers import pipeline
    from torch.utils.data import Dataset
    import tqdm

    pipe = pipeline("text-classification", device=0)


    class MyDataset(Dataset):
        def __len__(self):
            return 5000

        def __getitem__(self, i):
            return "This is a test"


    dataset = MyDataset()

    for batch_size in [1, 8, 64, 256]:
        print("-" * 30)
        print(f"Streaming batch_size={batch_size}")
        for out in tqdm.tqdm(pipe(dataset, batch_size=batch_size), total=len(dataset)):
            pass
.. code-block::

    # On GTX 970
    ------------------------------
    Streaming no batching
    100%|██████████████████████████████████████████████████████████████████████| 5000/5000 [00:26<00:00, 187.52it/s]
    ------------------------------
    Streaming batch_size=8
    100%|█████████████████████████████████████████████████████████████████████| 5000/5000 [00:04<00:00, 1205.95it/s]
    ------------------------------
    Streaming batch_size=64
    100%|█████████████████████████████████████████████████████████████████████| 5000/5000 [00:02<00:00, 2478.24it/s]
    ------------------------------
    Streaming batch_size=256
    100%|█████████████████████████████████████████████████████████████████████| 5000/5000 [00:01<00:00, 2554.43it/s]
    (diminishing returns, saturated the GPU)
Example where batching is mostly a slowdown:

.. code-block::

    class MyDataset(Dataset):
        def __len__(self):
            return 5000

        def __getitem__(self, i):
            if i % 64 == 0:
                n = 100
            else:
                n = 1
            return "This is a test" * n
This dataset contains an occasional sentence that is much longer than the others. In that case, the **whole** batch
needs to be 400 tokens long, so the whole batch will be [64, 400] instead of [64, 4], leading to the big slowdown (the
padding effect is sketched in code after the output below). Even worse, on bigger batches, the program simply crashes.


.. code-block::

    ------------------------------
    Streaming no batching
    100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:05<00:00, 183.69it/s]
    ------------------------------
    Streaming batch_size=8
    100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 265.74it/s]
    ------------------------------
    Streaming batch_size=64
    100%|██████████████████████████████████████████████████████████████████████| 1000/1000 [00:26<00:00, 37.80it/s]
    ------------------------------
    Streaming batch_size=256
    0%|                                                                                 | 0/1000 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "/home/nicolas/src/transformers/test.py", line 42, in <module>
        for out in tqdm.tqdm(pipe(dataset, batch_size=256), total=len(dataset)):
      ....
        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
    RuntimeError: CUDA out of memory. Tried to allocate 376.00 MiB (GPU 0; 3.95 GiB total capacity; 1.72 GiB already allocated; 354.88 MiB free; 2.46 GiB reserved in total by PyTorch)
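To make the padding effect concrete, here is a minimal sketch (the `bert-base-uncased` tokenizer is an illustrative
assumption, not part of this commit) showing how a single long sample inflates the whole padded batch:

.. code-block::

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    short = ["This is a test"] * 63
    long = ["This is a test" * 100]
    # Padding aligns every row to the longest sequence in the batch:
    batch = tokenizer(short + long, padding=True, return_tensors="pt")
    print(batch["input_ids"].shape)  # roughly torch.Size([64, 400]) instead of [64, 6]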
There are no good (general) solutions for this problem, and your mileage may vary depending on your use case. A rule
of thumb for users:

- **Measure performance on your load, with your hardware. Measure, measure, and keep measuring. Real numbers are the
  only way to go.**
- If you are latency constrained (a live product doing inference), don't batch.
- If you are using CPU, don't batch.
- If you are optimizing for throughput (you want to run your model on a bunch of static data), on GPU, then:

  - If you have no clue about the size of the sequence_length ("natural" data), don't batch by default; measure,
    tentatively try adding it, and add OOM checks to recover when it fails (and it will fail at some point if you
    don't control the sequence_length).
  - If your sequence_length is very regular, batching is much more likely to be VERY interesting; measure and push
    it until you get OOMs.
  - The larger the GPU, the more likely batching is to be interesting.
- As soon as you enable batching, make sure you can handle OOMs nicely (a recovery sketch follows this list).
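
A hedged sketch of such an OOM check (it reuses the `pipe` and `dataset` objects from the examples above; this is not a
library feature, just one way to recover):

.. code-block::

    import torch

    batch_size = 256
    while batch_size:
        try:
            for out in pipe(dataset, batch_size=batch_size):
                pass
            break
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            # Free cached blocks and retry the whole run with a smaller batch.
            torch.cuda.empty_cache()
            batch_size //= 2
            print(f"OOM, retrying with batch_size={batch_size}")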



Implementing a pipeline
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

4 changes: 4 additions & 0 deletions src/transformers/__init__.py
@@ -584,6 +584,7 @@
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING",
"MODEL_FOR_CAUSAL_LM_MAPPING",
"MODEL_FOR_CTC_MAPPING",
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING",
"MODEL_FOR_IMAGE_SEGMENTATION_MAPPING",
"MODEL_FOR_MASKED_LM_MAPPING",
@@ -594,6 +595,7 @@
"MODEL_FOR_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING",
"MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING",
"MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING",
"MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING",
"MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING",
"MODEL_MAPPING",
@@ -2430,6 +2432,7 @@
from .models.auto import (
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
MODEL_FOR_CAUSAL_LM_MAPPING,
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
MODEL_FOR_IMAGE_SEGMENTATION_MAPPING,
MODEL_FOR_MASKED_LM_MAPPING,
@@ -2440,6 +2443,7 @@
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_MAPPING,
4 changes: 4 additions & 0 deletions src/transformers/pipelines/automatic_speech_recognition.py
@@ -169,6 +169,10 @@ def _forward(self, model_inputs):
        elif model_class in MODEL_FOR_CTC_MAPPING.values():
            outputs = self.model(**model_inputs)
            tokens = outputs.logits.squeeze(0).argmax(dim=-1)
        else:
            # Fall back to greedy CTC decoding for unknown model classes.
            logger.warning("This is an unknown class, treating it as CTC.")
            outputs = self.model(**model_inputs)
            tokens = outputs.logits.squeeze(0).argmax(dim=-1)
        return tokens

    def postprocess(self, model_outputs):
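For context, a hedged end-to-end sketch of what this greedy CTC path amounts to (the checkpoint and the `raw_audio`
stand-in are illustrative assumptions, not part of the commit):

    import numpy as np
    import torch
    from transformers import AutoModelForCTC, AutoProcessor

    processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
    model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")

    # Stand-in for one second of real audio at 16 kHz.
    raw_audio = np.zeros(16_000, dtype=np.float32)
    inputs = processor(raw_audio, sampling_rate=16_000, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits          # (batch, time, vocab)
    tokens = logits.squeeze(0).argmax(dim=-1)    # greedy path, as in _forward above
    print(processor.batch_decode(tokens.unsqueeze(0)))  # decoding collapses repeats/blanks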