
Add batching in TokenClassificationPipeline #11251

Closed
wants to merge 9 commits

Conversation

@parakalan (Contributor) commented on Apr 14, 2021:

What does this PR do?

Currently, the NER pipeline in transformers iterates through the list of input sentences and processes them sequentially.
This PR adds batching support in the pipeline to decrease latency and use GPU more efficiently.

Relevant issue: #11244
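The PR's diff is not reproduced in full here; the following is a minimal sketch of what batched inference does under the hood, assuming a generic token-classification checkpoint (the model name and batch size are illustrative, and the `model_batch_size` kwarg proposed by this PR was never merged). Sentences are grouped, padded to the longest sequence in the group, and run through the model in a single forward pass.

```python
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer

model_name = "dslim/bert-base-NER"  # assumption: any token-classification checkpoint works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)
model.eval()

sentences = [
    "Hugging Face is based in New York City.",
    "Sundar Pichai is the CEO of Google.",
]

batch_size = 2
for start in range(0, len(sentences), batch_size):
    batch = sentences[start:start + batch_size]
    # Padding to the longest sentence in the batch is what makes batching possible,
    # and is also the efficiency risk raised in the review below.
    inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits  # shape: (batch, seq_len, num_labels)
    label_ids = logits.argmax(dim=-1)
    for ids, mask in zip(label_ids, inputs["attention_mask"]):
        # Keep only non-padding positions when mapping ids back to label names.
        labels = [model.config.id2label[i.item()] for i, m in zip(ids, mask) if m]
        print(labels)
```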

Benchmark Report

Without Batching

Device: CPU
No. of examples: 1000
Time taken: 283.28 s

Device: GPU
No. of examples: 1000
Time taken: 17.89 s

Please check the benchmark gist here

With Batching

Device: CPU
No. of examples: 1000
Batch size: 512
Time taken: 121.82 s

Device: GPU
No. of examples: 1000
Batch size: 512
Time taken: 2.78 s

Please check the benchmark gist here
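The original gist links are not preserved here; the benchmark could plausibly be reproduced along the lines of the sketch below. The model name is a placeholder, and the commented-out `model_batch_size` kwarg is the argument this PR proposes rather than an existing pipeline option.

```python
import time
from transformers import pipeline

# device=-1 runs on CPU; device=0 would select the first GPU.
ner = pipeline("ner", model="dslim/bert-base-NER", device=-1)

sentences = ["Hugging Face is based in New York City."] * 1000

start = time.time()
_ = ner(sentences)  # current behavior: one forward pass per sentence
print("sequential:", time.time() - start)

# With the change proposed in this PR, the same call would accept a batch size:
# _ = ner(sentences, model_batch_size=512)
```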

@parakalan changed the title from "WIP: Add batching in TokenClassificationPipeline" to "Add batching in TokenClassificationPipeline" on Apr 15, 2021.
@LysandreJik (Member) commented:

FYI there is also work done on this pipeline in #10568 if you want to give it a look! It doesn't concern batching, however.

@parakalan (Contributor, Author) commented:

Thanks, let me check that out.

@parakalan (Contributor, Author) commented:

Please review this, @LysandreJik, @Narsil, @joshdevins.

@Narsil (Contributor) left a review:
I appreciate the intent of this PR, but I think it is most likely to let users shoot themselves in the foot without even realizing it. I'm requesting changes on this PR for that reason (and it actually raises questions about other pipeline implementations).

Model batch size: 1
Device: CPU (i7-4790)
No. of examples: 200
Time taken: 10.69 s

Model batch size: 2
Device: CPU (batched)
No. of examples: 200
Time taken: 10.75 s

Model batch size: 2
Device: CPU (batched, 2nd run)
No. of examples: 200
Time taken: 34.85 s

Model batch size: 1
Device: GPU (GTX 970)
No. of examples: 200
Time taken: 0.713 s

Model batch size: 2
Device: GPU (batched)
No. of examples: 200
Time taken: 0.709 s

Model batch size: 2
Device: GPU (batched, 2nd run)
No. of examples: 200
Time taken: 8.98 s

The core of the issue is the extra padding tokens created while running inference. Those are very bad for overall efficiency, and in a live system mismatched sequence lengths are very likely to occur (from experience, I can say the effect can be overwhelmingly bad). It's almost never worth attempting batching in live production unless you're very sure about the alignment problem. A small sketch of this padding effect follows after the note below.

  • That being noted, other pipelines do sometimes batch, and batching can be effective if used properly, even though it is very hard to do at the pipeline level (because we receive strings rather than tokens, and tokens are what you need for proper alignment).
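A minimal sketch of the padding concern described above, assuming a generic BERT checkpoint (the model name and sentences are illustrative): when sequence lengths in a batch diverge, most of the tokens the model computes are padding.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

short = "Hello."
long = "This is a much longer sentence " * 20

# Padding to the longest sequence means the short sentence is mostly pad tokens.
batch = tokenizer([short, long], padding=True, return_tensors="pt")
mask = batch["attention_mask"]
total = mask.numel()
real = int(mask.sum())
print(f"{real}/{total} tokens are real; {total - real} are padding overhead")
```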

@@ -30,6 +31,7 @@ class TokenClassificationArgumentHandler(ArgumentHandler):
    def __call__(self, inputs: Union[str, List[str]], **kwargs):
        model_batch_size = kwargs.get("model_batch_size", 1)
Review comment (Contributor):

You are actually using this as a boolean, not an int.
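A minimal sketch of the distinction the reviewer is pointing at (illustrative, not the PR's actual code): treating the value as an on/off flag versus treating it as a real batch size.

```python
def batches_as_flag(sentences, model_batch_size=1):
    # "Boolean" usage: all sentences land in a single batch whenever the value > 1.
    return [sentences] if model_batch_size > 1 else [[s] for s in sentences]

def batches_as_size(sentences, model_batch_size=1):
    # Integer usage: the value actually bounds the number of sentences per batch.
    return [
        sentences[i:i + model_batch_size]
        for i in range(0, len(sentences), model_batch_size)
    ]
```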

    input_ids = tokens["input_ids"].cpu().numpy()[0]
    if self.framework == "tf":
        if model_batch_size > 1:
            warnings.warn("The `model_batch_size` argument is not supported for Tensorflow models. Ignoring")
Review comment (Contributor):

You're actually breaking TensorFlow here, right? Because you're no longer checking all sentences.

@parakalan (Contributor, Author) commented:

Closing this PR based on @Narsil's review. Thanks!
