
Remove message of checking apex.amp module and add tests for features of gradient accumulation/mixed precision training #46

Merged (6 commits merged into davidtvs:master on Jun 7, 2020)

Conversation

@NaleRaphael (Contributor) commented Jun 6, 2020

This PR fixes issue #45 and adds new test cases for the features implemented in PR #9.

A quick summary for this PR:

  1. The message `To enable mixed precision training, please install apex...` is removed. (solved in commit 227fc53)
  2. A silly mistake is fixed: batch_size was not passed into DataLoader, so the data loaders in the test cases had been running with the default batch_size=1. Though this does not affect the correctness of the tests, it still needed to be corrected. (solved in commit 1c549ec)
  3. Another mistake is fixed: the test cases never moved the model to the device declared in task.__init__(), so all tests were running on the CPU even when the pytest argument --cpu_only was not specified. (solved in commit c854714)
  4. More tests for the gradient accumulation and mixed precision features are added in commit e073ac8.

Note that a new dependency, pytest-mock, is added for the new test cases; a rough sketch of such a test follows below.
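
As an illustration only (this is not the test code added in this PR; the model, data, and loop below are placeholders), a pytest-mock based check for gradient accumulation could look roughly like this, using `mocker.spy` to count how often the optimizer actually steps:

```python
import torch
import torch.nn as nn


def test_gradient_accumulation_calls_step_once_per_accumulated_batch(mocker):
    # Tiny placeholder model and data; the real tests use their own Task helpers.
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    criterion = nn.MSELoss()
    accumulation_steps, num_batches = 4, 8

    # Spy on the optimizer so the real method still runs but calls are recorded.
    step_spy = mocker.spy(optimizer, "step")

    optimizer.zero_grad()
    for i in range(num_batches):
        inputs, targets = torch.randn(2, 4), torch.randn(2, 1)
        loss = criterion(model(inputs), targets) / accumulation_steps
        loss.backward()  # gradients accumulate across iterations
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # With 8 batches accumulated over 4, the optimizer should step exactly twice.
    assert step_spy.call_count == num_batches // accumulation_steps
```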

The original purpose of that message was to let users know that gradient
accumulation and mixed precision training are supported but require `apex`.

With the attention brought by issue davidtvs#45, the following things are
confirmed:

- Gradient accumulation still works properly without `apex.amp`,
  which is why it falls back to a plain `loss.backward()` when
  `apex.amp` is not available or `amp.initialize()` was not called.

- When mixed precision training is wanted, that is, the model
  and optimizer are wrapped by `amp.initialize()`, `amp.scale_loss()`
  is adopted automatically by the current implementation.

Therefore, the message about checking for the `apex.amp` module is
no longer necessary.
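
The fallback behaviour described above can be sketched roughly as follows (a sketch, not the library's exact implementation; it relies on the `_amp_stash` attribute that apex attaches to optimizers wrapped by `amp.initialize()` as the "amp is active" marker):

```python
IS_AMP_AVAILABLE = False
try:
    from apex import amp

    IS_AMP_AVAILABLE = True
except ImportError:
    pass


def backward_pass(loss, optimizer):
    if IS_AMP_AVAILABLE and hasattr(optimizer, "_amp_stash"):
        # Mixed precision: scale the loss so fp16 gradients do not underflow.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        # Fallback: plain backward pass; gradient accumulation still works.
        loss.backward()
```
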
This mistake made the batch size of every data loader fall back to the
default value of 1. Though it does not affect the correctness of the
test cases, it still needs to be corrected.

However, the `batch_size` of a `DataLoader` cannot be modified
after it is initialized. Therefore, it can only be determined
while generating tasks for the tests, which is why `batch_size`
and `steps` are moved into the signature of `__init__` of each
`Task`.
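
A hypothetical sketch of that change (the class and dataset below are placeholders, not the test suite's actual `Task` classes): because a `DataLoader`'s batch size is fixed at construction time, it has to be chosen when the task builds its loader, i.e. in `__init__`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class ExampleTask:  # placeholder name for a test task
    def __init__(self, batch_size=8, steps=100):
        self.batch_size = batch_size
        self.steps = steps
        dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        # batch_size must be passed here; it cannot be changed afterwards.
        self.dataloader = DataLoader(dataset, batch_size=batch_size)
```
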
This functionality was not added before, which made all tests run
on the CPU even when the pytest argument `--cpu_only` was not specified.
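
A sketch of the device handling described above (with hypothetical names; the actual tests wire this through their Task classes and the `--cpu_only` pytest option):

```python
import torch
import torch.nn as nn


def select_device(cpu_only: bool = False) -> torch.device:
    # Fall back to CPU when requested via --cpu_only or when CUDA is absent.
    if cpu_only or not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device("cuda")


# Inside a task's __init__ the model is then moved explicitly:
model = nn.Linear(4, 1).to(select_device(cpu_only=False))
```
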
@davidtvs (Owner) left a comment:

Thanks for another good contribution

@@ -4,6 +4,14 @@
import task as mod_task


try:
    import apex
davidtvs (Owner) commented:

I think this line can be changed to `from apex import amp` and we can then remove the local imports from the functions

@NaleRaphael (Contributor, Author) replied:

Yeah, I'll fix it.
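
A rough sketch of the suggested pattern (not the final committed code; the flag name is a placeholder): import `amp` once at module level inside try/except, so the individual test functions no longer need their own local `from apex import amp` imports.

```python
try:
    from apex import amp  # noqa: F401

    IS_APEX_AVAILABLE = True
except ImportError:
    amp = None
    IS_APEX_AVAILABLE = False
```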

    reason="`apex` module and gpu is required to run this test."
)
def test_gradient_accumulation_with_apex_amp(self, mocker):
    from apex import amp
davidtvs (Owner) commented:

Remove line (see comment about import)

)
class TestMixedPrecision:
    def test_mixed_precision(self, mocker):
        from apex import amp
davidtvs (Owner) commented:

Remove line (see comment about import)

@davidtvs (Owner) left a comment:

Forgot to select the proper radio button

@davidtvs (Owner) commented Jun 6, 2020

/black-check

@github-actions (bot) left a comment:

No linting violations have been found in this PR.

@davidtvs (Owner) commented Jun 6, 2020

/flake8-lint

@github-actions (bot) left a comment:

Lintly has detected code quality issues in this pull request.

@@ -4,6 +4,14 @@
import task as mod_task


try:
    import apex
github-actions (bot) commented:

F401: 'apex' imported but unused

@davidtvs (Owner) left a comment:

Nice work. Thanks!

@davidtvs merged commit 23a23cf into davidtvs:master on Jun 7, 2020