MAINT: remove message of checking apex.amp module
The original purpose of that message was to let users know that gradient
accumulation and mixed precision training are supported, but that `apex`
is required.

Thanks to the attention brought by issue #45, the following points were
confirmed:

- Gradient accumulation still works properly without `apex.amp`: the
  implementation falls back to a normal `loss.backward()` when
  `apex.amp` is not available or `amp.initialize()` has not been called.

- When mixed precision training is required, i.e. the model and
  optimizer have been wrapped by `amp.initialize()`, `amp.scale_loss()`
  is adopted automatically in the current implementation.

Therefore, the message checking for the `apex.amp` module is no longer
necessary.
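
For reference, a minimal sketch of the fallback behavior described above
(an illustration, not the library's verbatim code; note that
`amp.initialize()` sets an `_amp_stash` attribute on the wrapped
optimizer, so checking for it is one way to detect that mixed precision
is active):

    try:
        from apex import amp

        IS_AMP_AVAILABLE = True
    except ImportError:
        IS_AMP_AVAILABLE = False

    def backward(loss, optimizer):
        # Use `amp.scale_loss()` only when apex is installed and the
        # optimizer has been wrapped by `amp.initialize()`.
        if IS_AMP_AVAILABLE and hasattr(optimizer, "_amp_stash"):
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # Fall back to a normal backward pass.
            loss.backward()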
NaleRaphael committed Jun 6, 2020
1 parent 98c4004 commit 227fc53
Showing 2 changed files with 4 additions and 11 deletions.
5 changes: 4 additions & 1 deletion README.md
@@ -87,12 +87,15 @@ accumulation_steps = desired_batch_size // real_batch_size
 dataset = ...
 
 # Beware of the `batch_size` used by `DataLoader`
-trainloader = DataLoader(dataset, batch_size=real_bs, shuffle=True)
+trainloader = DataLoader(dataset, batch_size=real_batch_size, shuffle=True)
 
 model = ...
 criterion = ...
 optimizer = ...
 
+# (Optional) With this setting, `amp.scale_loss()` will be adopted automatically.
+# model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
+
 lr_finder = LRFinder(model, optimizer, criterion, device="cuda")
 lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode="exp", accumulation_steps=accumulation_steps)
 lr_finder.plot()
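
For context, a rough sketch of how gradient accumulation of the kind
used above is typically implemented in a training step (an illustration
under common conventions, not necessarily this library's exact
internals):

    for i, (inputs, labels) in enumerate(trainloader):
        outputs = model(inputs)
        # Scale the loss so the accumulated gradient matches the
        # gradient of one large batch of size `desired_batch_size`.
        loss = criterion(outputs, labels) / accumulation_steps
        loss.backward()  # gradients accumulate in each param's .grad
        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()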
10 changes: 0 additions & 10 deletions torch_lr_finder/lr_finder.py
@@ -15,17 +15,7 @@
 
     IS_AMP_AVAILABLE = True
 except ImportError:
-    import logging
-
-    logging.basicConfig()
-    logger = logging.getLogger(__name__)
-    logger.warning(
-        "To enable mixed precision training, please install `apex`. "
-        "Or you can re-install this package by the following command:\n"
-        '  pip install torch-lr-finder -v --global-option="amp"'
-    )
     IS_AMP_AVAILABLE = False
-    del logging
 
 
 class DataLoaderIter(object):
