
Commit

Merge branch 'batchnorm'
sublee committed May 26, 2019
2 parents 67ca598 + abebbee commit b403c8b
Showing 5 changed files with 12 additions and 3 deletions.
1 change: 1 addition & 0 deletions .travis.yml
@@ -10,6 +10,7 @@ python:
 env:
 - PYTORCH=1.0.0
 - PYTORCH=1.0.1
+- PYTORCH=1.1.0

 stages:
 - lint
2 changes: 1 addition & 1 deletion README.ko.md
@@ -51,7 +51,7 @@ GPipe makes this possible in two ways: pipeline parallelism and checkpointing
 Currently, torchgpipe supports the following environments:

 - Python 3.6 or later
-- PyTorch 1.0
+- PyTorch 1.0 or later

 First, install `torchgpipe` from PyPI:
2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@ A [GPipe](https://arxiv.org/abs/1811.06965) implementation in PyTorch.
 Prerequisites are:

 - Python 3.6+
-- PyTorch 1.0
+- PyTorch 1.0+
 - Your `nn.Sequential` module

 Install via PyPI:
2 changes: 1 addition & 1 deletion setup.py
@@ -26,7 +26,7 @@
     packages=['torchgpipe'],
     package_data={'torchgpipe': ['py.typed']},

-    install_requires=['torch>=1,<1.1'],
+    install_requires=['torch>=1'],
     setup_requires=['pytest-runner'],
     tests_require=['pytest>=4'],
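The effect of dropping the upper bound can be sanity-checked with the third-party `packaging` library (a sketch; the specifier strings are the ones from this setup.py hunk):

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

old = SpecifierSet(">=1,<1.1")  # previous pin in setup.py
new = SpecifierSet(">=1")       # after this commit

# PyTorch 1.1.0 was excluded by the old upper bound but is now allowed,
# matching the new PYTORCH=1.1.0 entry in the Travis build matrix.
print(Version("1.1.0") in old)  # False
print(Version("1.1.0") in new)  # True
```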
8 changes: 8 additions & 0 deletions torchgpipe/batchnorm.py
@@ -46,6 +46,8 @@ def forward_pre_hook(self, bn: BatchNorm, inputs: Tuple[Tensor, ...]) -> None:

         # Don't track the running stats of this batch. It is already deferred.
         bn.track_running_stats = False
+        bn.momentum_ = bn.momentum
+        bn.momentum = None

         # Skip if this forward pass is triggered by checkpoint recomputation.
         if is_recomputing():
@@ -72,6 +74,12 @@ def forward_pre_hook(self, bn: BatchNorm, inputs: Tuple[Tensor, ...]) -> None:

     def forward_hook(self, bn: BatchNorm, input: Tensor, output: Tensor) -> None:
         # Any internal state modified by this hook should not be visible to users.
         bn.track_running_stats = True
+        try:
+            bn.momentum = bn.momentum_
+        except AttributeError:
+            pass
+        else:
+            del bn.momentum_

     def backward_hook(self, bn: BatchNorm,
                       grad_input: Tensor,
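The save-and-restore pattern in these two hooks can be reproduced on a plain `BatchNorm1d` with ordinary PyTorch forward hooks (a sketch, not torchgpipe's actual DeferredBatchNorm; the hook names `pre` and `post` are made up here). In PyTorch, `momentum = None` switches batch norm to a cumulative moving average, and a module with `track_running_stats = False` in training mode leaves its running statistics untouched:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)

def pre(module, inputs):
    # Hide this batch from the running statistics; in torchgpipe its
    # contribution is accumulated elsewhere and applied deferred.
    module.track_running_stats = False
    module.momentum_ = module.momentum  # stash the user's momentum
    module.momentum = None

def post(module, inputs, output):
    # Restore state so the modification is invisible to users.
    module.track_running_stats = True
    try:
        module.momentum = module.momentum_
    except AttributeError:
        pass
    else:
        del module.momentum_

bn.register_forward_pre_hook(pre)
bn.register_forward_hook(post)

out = bn(torch.randn(8, 4))

# The module looks unmodified afterwards, and running_mean is still zeros.
print(bn.track_running_stats, bn.momentum)            # True 0.1
print(torch.equal(bn.running_mean, torch.zeros(4)))   # True
```

The `try`/`except AttributeError` in the restore step mirrors the commit: if the pre-hook never ran (so `momentum_` was never stashed), the post-hook leaves `momentum` alone instead of raising.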
