Redesign DeferredBatchNorm
Switch from a hook-based implementation to a subclass-based one.

Inspired by the implementation of SyncBatchNorm, which was introduced in PyTorch 1.1.
sublee authored and GitHub Enterprise committed Jun 7, 2019
1 parent 7b1d38f commit bf51a06
Showing 6 changed files with 252 additions and 175 deletions.
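
For orientation, a minimal before/after sketch of the API change, using only names that appear in this diff (the conversion call mirrors nn.SyncBatchNorm.convert_sync_batchnorm mentioned above; the surrounding model and chunks value are illustrative):

    # Sketch only: old hook-based patching vs. new subclass-based conversion.
    import torch.nn as nn
    from torchgpipe.batchnorm import DeferredBatchNorm

    model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))

    # Before this commit (hook-based): BatchNorm modules were patched in place.
    #   from torchgpipe.batchnorm import patch_deferred_batch_norm
    #   model.apply(patch_deferred_batch_norm)

    # After this commit (subclass-based): BatchNorm modules are replaced by
    # DeferredBatchNorm, in the style of nn.SyncBatchNorm.convert_sync_batchnorm.
    model = DeferredBatchNorm.convert_deferred_batch_norm(model, chunks=4)
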
stubs/torch/nn/__init__.pyi: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
#MODIFIED BY TORCHGPIPE
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple, TypeVar, Union

from torch import Tensor, device

@@ -34,7 +34,7 @@ class Module:
def apply(self, fn: Callable[[Module], None]) -> Module: ...

def register_buffer(self, name: str, tensor: Tensor) -> None: ...
def register_parameter(self, name: str, param: Parameter) -> None: ...
def register_parameter(self, name: str, param: Union[Parameter, None]) -> None: ...

def register_backward_hook(self, hook: __Hook2) -> __RemovableHandle: ...
def register_forward_pre_hook(self, hook: __Hook1) -> __RemovableHandle: ...
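The widened register_parameter signature matches what PyTorch itself does at runtime: _BatchNorm registers None for weight and bias when affine=False, a path the new subclass presumably exercises as well. A tiny illustration (not part of this commit):

    import torch.nn as nn

    bn = nn.BatchNorm2d(3, affine=False)
    # With affine=False, weight and bias are registered as None parameters,
    # which is why the stub now accepts Union[Parameter, None].
    assert bn.weight is None and bn.bias is None
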
stubs/torch/nn/modules/batchnorm.pyi: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
from typing import Any, Callable, Iterable, Iterator, Optional

from torch import Tensor
from torch.nn import Module
from torch.nn import Module, Parameter


class _BatchNorm(Module):
@@ -12,8 +12,8 @@ class _BatchNorm(Module):
affine: bool
track_running_stats: bool

weight: Tensor
bias: Tensor
weight: Parameter
bias: Parameter
running_mean: Tensor
running_var: Tensor
num_batches_tracked: Tensor
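Typing weight and bias as Parameter rather than Tensor reflects how the affine case behaves at runtime; for example (not part of this commit):

    import torch.nn as nn

    bn = nn.BatchNorm2d(3)  # affine=True by default
    assert isinstance(bn.weight, nn.Parameter)
    assert isinstance(bn.bias, nn.Parameter)
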
tests/test_deferred_batch_norm.py: 83 additions & 74 deletions
@@ -1,25 +1,14 @@
from copy import deepcopy
from itertools import chain

import pytest
import torch
import torch.nn as nn
import torch.optim as optim

from torchgpipe.batchnorm import patch_deferred_batch_norm
from torchgpipe.batchnorm import DeferredBatchNorm


@pytest.fixture
def bn():
torch.manual_seed(0)
return nn.BatchNorm2d(3)


@pytest.fixture
def def_bn():
torch.manual_seed(0)
module = nn.BatchNorm2d(3)
module.apply(patch_deferred_batch_norm)
return module
CHUNKS = 4


def tilt_dist(input):
@@ -31,69 +20,97 @@ def tilt_dist(input):

# Tilt mean by single batch.
for i, single in enumerate(input):
single += 10**i
single += 2**i

return input


def chunked_forward(model, input):
def chunked_forward(model, input, chunks=CHUNKS):
output_chunks = []

for chunk in input.chunk(4):
for chunk in input.chunk(chunks):
output_chunks.append(model(chunk))

return torch.cat(output_chunks)


def test_running_stats(bn, def_bn):
input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)
@pytest.mark.parametrize('chunks', [1, 4])
@pytest.mark.parametrize('input_requires_grad', [True, False])
def test_transparency(chunks, input_requires_grad):
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks)

bn(input)
y = chunked_forward(def_bn, input)
y.sum().backward() # flush buffer
input1 = torch.rand(16, 3, 224, 224)
input1 = tilt_dist(input1)
input2 = input1.clone()
input1.requires_grad = input_requires_grad
input2.requires_grad = input_requires_grad

assert torch.allclose(bn.running_mean, def_bn.running_mean, atol=1e-4)
assert torch.allclose(bn.running_var, def_bn.running_var, atol=1e-4)
output1 = chunked_forward(bn, input1, chunks=chunks)
output2 = chunked_forward(dbn, input2, chunks=chunks)

assert torch.allclose(output1, output2, atol=1e-4)

def test_noop():
bn = nn.BatchNorm2d(3, track_running_stats=False)
bn.apply(patch_deferred_batch_norm)
y = bn(torch.rand(16, 3, 224, 224))
y.mean().backward()
output1.mean().backward()
output2.mean().backward()

assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)

if input_requires_grad:
assert input1.grad is not None
assert input2.grad is not None
assert torch.allclose(input1.grad, input2.grad, atol=1e-4)


def test_eval(bn, def_bn):
@pytest.mark.parametrize('momentum', [0.1, None])
def test_running_stats(momentum):
bn = nn.BatchNorm2d(3, momentum=momentum)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)

bn(input)
y = chunked_forward(def_bn, input)
y.sum().backward() # flush buffer
chunked_forward(dbn, input)

assert torch.allclose(bn.running_mean, dbn.running_mean, atol=1e-4)
assert torch.allclose(bn.running_var, dbn.running_var, atol=1e-4)

bn.eval()
def_bn.eval()

assert torch.allclose(bn(input), def_bn(input), atol=1e-4)
def test_convert_deferred_batch_norm():
bn = nn.BatchNorm2d(3, track_running_stats=False)
bn = DeferredBatchNorm.convert_deferred_batch_norm(bn, chunks=CHUNKS)
assert type(bn) is nn.BatchNorm2d # because of track_running_stats=False

dbn = DeferredBatchNorm(3, chunks=CHUNKS)
dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS)
assert dbn is dbn_again

dbn_again = DeferredBatchNorm.convert_deferred_batch_norm(dbn, chunks=CHUNKS + 1)
assert dbn is not dbn_again # because of different chunks


def test_backward(def_bn):
def test_eval():
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)

output = chunked_forward(def_bn, input)
bn(input)
chunked_forward(dbn, input)

bn.eval()
dbn.eval()

assert torch.allclose(bn(input), dbn(input), atol=1e-4)

# Should not raise this error:
#
# RuntimeError: one of the variables needed for gradient computation has
# been modified by an inplace operation
#
output.sum().backward()

def test_optimize():
bn = nn.BatchNorm2d(3)
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

def test_optimize(bn, def_bn):
opt = optim.SGD(chain(bn.parameters(), def_bn.parameters()), lr=0.1)
opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)

for i in range(5):
input = torch.rand(16, 3, 224, 224)
@@ -104,42 +121,32 @@ def test_optimize(bn, def_bn):
a = y.sum()
a.backward()

y = chunked_forward(def_bn, input)
y = chunked_forward(dbn, input)
b = y.sum()
b.backward()

opt.step()

# eval
bn.eval()
def_bn.eval()
dbn.eval()

with torch.no_grad():
assert torch.allclose(bn(input), def_bn(input), atol=1e-1 * (10**i))
assert torch.allclose(bn(input), dbn(input), atol=1e-1 * (10**i))


def test_conv_bn():
torch.manual_seed(0)
bn = nn.Sequential(
nn.Conv2d(3, 3, 1),
nn.BatchNorm2d(3),
)

torch.manual_seed(0)
def_bn = nn.Sequential(
nn.Conv2d(3, 3, 1),
nn.BatchNorm2d(3),
)
def_bn.apply(patch_deferred_batch_norm)

opt = optim.SGD(chain(bn.parameters(), def_bn.parameters()), lr=0.1)
bn = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

input = torch.rand(16, 3, 224, 224)
input = tilt_dist(input)

opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=0.1)

# 1st step
a = bn(input)
b = chunked_forward(def_bn, input)
b = chunked_forward(dbn, input)

# Outputs are different. (per-mini-batch vs. per-micro-batch)
assert not torch.allclose(a, b)
@@ -150,28 +157,30 @@ def test_conv_bn():
opt.zero_grad()

# Conv layers are also trained differently because of their different outputs.
assert not torch.allclose(bn[0].weight, def_bn[0].weight)
assert not torch.allclose(bn[0].weight, dbn[0].weight)

# But BNs track identical running stats.
assert torch.allclose(bn[1].running_mean, def_bn[1].running_mean, atol=1e+8)
assert torch.allclose(bn[1].running_var, def_bn[1].running_var, atol=1e+22)
assert torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
assert torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3)

# 2nd step
a = bn(input)
b = chunked_forward(def_bn, input)
b = chunked_forward(dbn, input)
a.sum().backward()
b.sum().backward()

# BNs can't track identical running stats due to the different conv layers.
assert not torch.allclose(bn[1].running_mean, def_bn[1].running_mean, atol=1e+8)
assert not torch.allclose(bn[1].running_var, def_bn[1].running_var, atol=1e+22)
assert not torch.allclose(bn[1].running_mean, dbn[1].running_mean, atol=1e-4)
assert not torch.allclose(bn[1].running_var, dbn[1].running_var, atol=1e+3)


def test_input_requiring_grad():
dbn = DeferredBatchNorm(3, chunks=CHUNKS)

def test_input_requiring_grad(def_bn):
input = torch.rand(16, 3, 224, 224, requires_grad=True)
input = tilt_dist(input)

chunked_forward(def_bn, input)
chunked_forward(dbn, input)

assert not def_bn.sum.requires_grad
assert def_bn.sum.grad_fn is None
assert not dbn.sum.requires_grad
assert dbn.sum.grad_fn is None
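
The tests above check observable behaviour only. As a rough mental model (a toy sketch, not the actual torchgpipe code), a deferred batch norm accumulates per-micro-batch sums and commits the running statistics once per mini-batch, which is consistent with the detached dbn.sum buffer asserted above:

    import torch

    class TinyDeferredStats:
        """Toy sketch: accumulate micro-batch sums, commit stats once per mini-batch."""

        def __init__(self, num_features, chunks, momentum=0.1):
            self.chunks = chunks
            self.momentum = momentum
            self.running_mean = torch.zeros(num_features)
            self.running_var = torch.ones(num_features)
            self.sum = torch.zeros(num_features)
            self.sum_squares = torch.zeros(num_features)
            self.counter = 0
            self.tracked = 0

        def track(self, x):  # x: a (N, C, H, W) micro-batch
            with torch.no_grad():  # keeps self.sum detached, as the test asserts
                self.sum += x.sum(dim=(0, 2, 3))
                self.sum_squares += (x * x).sum(dim=(0, 2, 3))
                self.counter += x.numel() // x.size(1)
                self.tracked += 1
                if self.tracked == self.chunks:  # commit once per mini-batch
                    mean = self.sum / self.counter
                    # Biased variance for brevity; real BatchNorm tracks the
                    # unbiased estimate for its running_var.
                    var = self.sum_squares / self.counter - mean * mean
                    self.running_mean += self.momentum * (mean - self.running_mean)
                    self.running_var += self.momentum * (var - self.running_var)
                    self.sum.zero_()
                    self.sum_squares.zero_()
                    self.counter = 0
                    self.tracked = 0
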
tests/test_gpipe.py: 23 additions & 5 deletions
@@ -1,3 +1,4 @@
from copy import deepcopy
import time

import pytest
@@ -417,17 +418,34 @@ def forward(self, x):
@pytest.mark.parametrize('checkpoint', ['never', 'always', 'except_last'])
def test_deferred_batch_norm(checkpoint):
bn = nn.BatchNorm2d(3)
bn_under_gpipe = nn.BatchNorm2d(3)
gpipe_bn = deepcopy(bn)
gpipe = GPipe(nn.Sequential(gpipe_bn), balance=[1], devices=['cpu'], chunks=2,
checkpoint=checkpoint, deferred_batch_norm=True)

x = torch.rand(4, 3, 10, 10)
gpipe(x).mean().backward()
bn(x).mean().backward()

gpipe = GPipe(nn.Sequential(bn_under_gpipe), balance=[1], devices=['cpu'], chunks=2,
assert torch.allclose(gpipe[0].running_mean, bn.running_mean, atol=1e-4)
assert torch.allclose(gpipe[0].running_var, bn.running_var, atol=1e-4)


@pytest.mark.parametrize('checkpoint', ['never', 'always'])
def test_deferred_batch_norm_params(checkpoint):
bn = nn.BatchNorm2d(3)
gpipe_bn = deepcopy(bn)
gpipe = GPipe(nn.Sequential(gpipe_bn), balance=[1], devices=['cpu'], chunks=1,
checkpoint=checkpoint, deferred_batch_norm=True)

x = torch.rand(4, 3, 10, 10)
gpipe(x).mean().backward()
bn(x)
bn(x).mean().backward()

assert gpipe[0].weight.grad is not None
assert gpipe[0].bias.grad is not None

assert torch.allclose(bn_under_gpipe.running_mean, bn.running_mean, atol=1e-4)
assert torch.allclose(bn_under_gpipe.running_var, bn.running_var, atol=1e-4)
assert torch.allclose(gpipe[0].weight.grad, bn.weight.grad, atol=1e-4)
assert torch.allclose(gpipe[0].bias.grad, bn.bias.grad, atol=1e-4)


def test_current_microbatch():
