
Benchmark Performance for Baseline vs Pipeline-1 #11

Closed
vibhatha opened this issue Feb 26, 2020 · 5 comments
@vibhatha

In the speed benchmarks, the pipeline-1 benchmark time is higher than that of the baseline benchmarks. Is there a clear reason why there is a significant overhead for pipeline-1 compared with the baseline experiments?

What I understood from the script is that the baseline runs on one GPU core. Is this right?
And the pipeline also runs on one GPU core? Is this right?

@sublee sublee self-assigned this Feb 26, 2020
@sublee sublee added the question Further information is requested label Feb 26, 2020
@sublee
Contributor

sublee commented Feb 26, 2020

Both the pipeline-1 and baseline benchmarks run on a single GPU. Unlike the baseline, pipeline-1 enables checkpointing, which has an overhead. This overhead is worthwhile when there is actual pipeline parallelism, but pipeline-1 does not perform any parallelism.
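
For reference, a minimal sketch of the two configurations being compared might look like the following. This is not the actual benchmark script; the model, layer sizes, and chunk count are illustrative assumptions.

```python
import torch
from torch import nn
from torchgpipe import GPipe

def make_model():
    # Placeholder layers; the real benchmarks use much larger models.
    return nn.Sequential(
        nn.Linear(1024, 1024),
        nn.ReLU(),
        nn.Linear(1024, 1024),
    )

# Baseline: the plain module on one GPU, with no micro-batching and no checkpointing.
baseline = make_model().to('cuda:0')

# pipeline-1: the same module wrapped in GPipe with a single partition, so it still
# runs on one GPU, but micro-batching and checkpointing (the default
# checkpoint='except_last') are active, adding overhead without any parallelism.
pipeline1 = GPipe(make_model(), balance=[3], chunks=8)

x = torch.rand(128, 1024, device='cuda:0')
out_base = baseline(x)
out_pipe = pipeline1(x)
```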

@vibhatha
Author

vibhatha commented Feb 28, 2020

About checkpointing, if I understood correctly, the overhead comes from re-running the forward pass at the end of a micro-batch. Correct me if I am wrong.

For checkpointing, there are a couple of modes:

When I try to use 'never', it runs out of memory.

I tried 'except_last' and it works fine, and with the 'always' option the performance is not as good as with 'except_last'.

So, since I used 'except_last' when running pipeline-1, do I get the minimum overhead from re-running the forward pass? Am I right?

Also, if I use the 'always' option, does it re-run the forward for each micro-batch?
And if I use 'except_last', does it re-run only for the last micro-batch?

In addition, what is the difference between 'never' and 'except_last'?

@sublee
Contributor

sublee commented Mar 2, 2020

There seems to be a misunderstanding about 'except_last'. When you choose 'except_last', torchgpipe re-runs every micro-batch except the last one. For example, suppose we use 8 micro-batches. Then each option re-runs the following number of micro-batches:

  • 'never': 0
  • 'always': 8
  • 'except_last': 7
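
In code, the mode is selected with the checkpoint argument when constructing GPipe. A minimal sketch follows, with the model and chunk count assumed for illustration:

```python
import torch
from torch import nn
from torchgpipe import GPipe

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# With chunks=8, the backward pass re-runs the forward for this many micro-batches:
#   checkpoint='never'       -> 0  (no recomputation, highest memory use)
#   checkpoint='except_last' -> 7  (the default: every micro-batch but the last)
#   checkpoint='always'      -> 8  (every micro-batch, lowest memory use)
gpipe = GPipe(model, balance=[3], chunks=8, checkpoint='except_last')

x = torch.rand(64, 1024, device=gpipe.devices[0])
out = gpipe(x)
out.mean().backward()  # recomputation (if any) happens here, during backward
```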

@vibhatha
Author

vibhatha commented Mar 2, 2020

I understand. I will test this with smaller batch sizes. Thanks for the clarification on this point.

@sublee
Contributor

sublee commented Mar 2, 2020

You are welcome. I will close this issue.

@sublee sublee closed this as completed Mar 2, 2020