Commit

Experiments for ResNet-101 Accuracy Benchmark
clint(백운혁) authored and sublee committed Sep 20, 2019
1 parent 5ec5053 commit 921f27b
Showing 11 changed files with 496 additions and 14 deletions.
14 changes: 14 additions & 0 deletions README.ko.md
@@ -108,6 +108,20 @@ No parallelism is applied and only the checkpointing overhead remains, so compared to naive
The experiment code can be found in
[examples/resnet101_performance_benchmark](examples/resnet101_performance_benchmark).

### ResNet-101 Accuracy Benchmark

Batch size | torchgpipe | nn.DataParallel | paper
---- | -----: | -----: | -----:
256 | 21.99±0.13 | 22.02±0.11 | 22.08±0.06
1k | 22.24±0.19 | 22.04±0.24 | N/A
4k | 22.13±0.09 | N/A | N/A

This reproduces the ResNet-101 accuracy (error rate) benchmark reported in
Table 2(c) of the [Accurate, Large Minibatch SGD paper](https://arxiv.org/abs/1706.02677).

The experiment code can be found in
[examples/resnet101_accuracy_benchmark](examples/resnet101_accuracy_benchmark).

### AmoebaNet-D Speed Benchmark

Experiment | torchgpipe | GPipe (paper)
14 changes: 14 additions & 0 deletions README.md
@@ -117,6 +117,20 @@ overhead.
The reproducible code can be found in
[examples/resnet101_performance_benchmark](examples/resnet101_performance_benchmark).

### ResNet-101 Accuracy Benchmark

Batch size | torchgpipe | nn.DataParallel | paper
---- | -----: | -----: | -----:
256 | 21.99±0.13 | 22.02±0.11 | 22.08±0.06
1k | 22.24±0.19 | 22.04±0.24 | N/A
4k | 22.13±0.09 | N/A | N/A

The table shows the reproduced accuracy (top-1 error rate) benchmark on ResNet-101,
as reported in Table 2(c) of the [Accurate, Large Minibatch SGD paper](https://arxiv.org/abs/1706.02677).

The reproducible code can be found in
[examples/resnet101_accuracy_benchmark](examples/resnet101_accuracy_benchmark).
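
For orientation, the two non-paper columns differ only in how the same sequential ResNet-101 is wrapped: `nn.DataParallel` replicates the whole model on each GPU, while torchgpipe splits it into pipeline stages. The sketch below is illustrative rather than the benchmark script itself; it assumes the `resnet` package from [examples/resnet](examples/resnet) is importable, and the 50/50 split and `chunks` value are placeholders, not the tuned settings behind the numbers above.

```python
import torch.nn as nn
from torchgpipe import GPipe

from resnet import build_resnet  # examples/resnet in this repository

model = build_resnet([3, 4, 23, 3])  # ResNet-101 as an nn.Sequential

# Data parallelism: replicate the whole model on each GPU.
dp_model = nn.DataParallel(model, device_ids=[0, 1])

# Pipeline parallelism: split the sequential model across 2 GPUs and feed
# each mini-batch as several micro-batches (pick one wrapper in practice).
n = len(model)
gpipe_model = GPipe(model, balance=[n // 2, n - n // 2], chunks=8)
```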

### AmoebaNet-D Performance Benchmark

Experiment | torchgpipe | GPipe (original)
19 changes: 19 additions & 0 deletions docs/benchmarks.rst
@@ -23,6 +23,25 @@ can be found in `examples/resnet101_performance_benchmark`_.
.. _examples/resnet101_performance_benchmark:
https://github.com/kakaobrain/torchgpipe/tree/master/examples/resnet101_performance_benchmark

ResNet-101 Accuracy Benchmark
-----------------------------

================ ===============
Experiment top-1 error (%)
================ ===============
dataparallel-256 22.02±0.11
dataparallel-1k 22.04±0.24
pipeline-256 21.99±0.13
pipeline-1k 22.24±0.19
pipeline-4k 22.13±0.09
================ ===============

The code is reproducible on Tesla P40 GPUs, and the experiment details
can be found in `examples/resnet101_accuracy_benchmark`_.

.. _examples/resnet101_accuracy_benchmark:
https://github.com/kakaobrain/torchgpipe/tree/master/examples/resnet101_accuracy_benchmark

AmoebaNet-D
~~~~~~~~~~~

15 changes: 8 additions & 7 deletions examples/resnet/__init__.py
@@ -17,6 +17,7 @@

def build_resnet(layers: List[int],
num_classes: int = 1000,
inplace: bool = False
) -> nn.Sequential:
"""Builds a ResNet as a simple sequential model.
@@ -26,7 +27,7 @@ def build_resnet(layers: List[int],
"""
inplanes = 64

def make_layer(planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
def make_layer(planes: int, blocks: int, stride: int = 1, inplace: bool = False) -> nn.Sequential:
nonlocal inplanes

downsample = None
@@ -38,10 +39,10 @@ def make_layer(planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
)

layers = []
layers.append(bottleneck(inplanes, planes, stride, downsample))
layers.append(bottleneck(inplanes, planes, stride, downsample, inplace))
inplanes = planes * 4
for _ in range(1, blocks):
layers.append(bottleneck(inplanes, planes))
layers.append(bottleneck(inplanes, planes, inplace=inplace))

return nn.Sequential(*layers)

@@ -52,10 +53,10 @@ def make_layer(planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
('relu', nn.ReLU()),
('maxpool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),

('layer1', make_layer(64, layers[0])),
('layer2', make_layer(128, layers[1], stride=2)),
('layer3', make_layer(256, layers[2], stride=2)),
('layer4', make_layer(512, layers[3], stride=2)),
('layer1', make_layer(64, layers[0], inplace=inplace)),
('layer2', make_layer(128, layers[1], stride=2, inplace=inplace)),
('layer3', make_layer(256, layers[2], stride=2, inplace=inplace)),
('layer4', make_layer(512, layers[3], stride=2, inplace=inplace)),

('avgpool', nn.AdaptiveAvgPool2d((1, 1))),
('flat', nn.Flatten()),
10 changes: 6 additions & 4 deletions examples/resnet/bottleneck.py
@@ -75,22 +75,24 @@ def bottleneck(inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
inplace: bool = False,
) -> nn.Sequential:
"""Creates a bottlenect block in ResNet as a :class:`nn.Sequential`."""
"""Creates a bottleneck block in ResNet as a :class:`nn.Sequential`."""

layers: NamedModules = OrderedDict()
layers['twin'] = Twin()

layers['conv1'] = Gutter(conv1x1(inplanes, planes))
layers['bn1'] = Gutter(nn.BatchNorm2d(planes))
layers['relu1'] = Gutter(nn.ReLU())
layers['relu1'] = Gutter(nn.ReLU(inplace=inplace))

layers['conv2'] = Gutter(conv3x3(planes, planes, stride))
layers['bn2'] = Gutter(nn.BatchNorm2d(planes))
layers['relu2'] = Gutter(nn.ReLU())
layers['relu2'] = Gutter(nn.ReLU(inplace=inplace))

layers['conv3'] = Gutter(conv1x1(planes, planes * 4))
layers['bn3'] = Gutter(nn.BatchNorm2d(planes * 4))
layers['residual'] = Residual(downsample)
layers['relu3'] = nn.ReLU()
layers['relu3'] = nn.ReLU(inplace=inplace)

return nn.Sequential(layers)
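
The substantive change above is the new `inplace` flag threaded through to each `nn.ReLU`. A small, self-contained sketch of what that toggle does; the tensor shapes below are arbitrary and chosen only for illustration:

```python
import torch
import torch.nn as nn

x = torch.randn(8, 64, 56, 56)  # arbitrary activation-shaped tensor

# inplace=False (the default kept by this commit): ReLU allocates a new
# output tensor and leaves its input untouched.
out = nn.ReLU(inplace=False)(x.clone())

# inplace=True: ReLU overwrites its input buffer, saving one activation-sized
# allocation per ReLU. That matters for large ResNet-101 batches, but it is
# only safe when nothing else still needs the pre-activation values.
y = x.clone()
out_inplace = nn.ReLU(inplace=True)(y)

assert torch.equal(out, out_inplace)  # same result
assert torch.equal(y, out_inplace)    # but y itself was modified in place

# With this commit the flag is threaded through the builders, e.g.:
#   model = build_resnet([3, 4, 23, 3], inplace=True)  # ResNet-101
```
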
59 changes: 59 additions & 0 deletions examples/resnet101_accuracy_benchmark/README.md
@@ -0,0 +1,59 @@
# ResNet-101 Accuracy Benchmark

This example reproduces the accuracy benchmark on ResNet-101 reported in
Table 2(c) of the [Accurate, Large Minibatch SGD][] paper.

Every experiment setting is optimized for Tesla P40 GPUs.

## Result

Experiment | num gpus | batch size (kn) | learning rate | top-1 error (%) | throughput (samples/sec) | speed up
----------------------- | --------:| ---:|-------------: | ---------------:|-------------------------:|---------:
reference-256 [paper][] | 8 | 256 | 0.1 | 22.08±0.06 | N/A | N/A
reference-8k [paper][] | 256 | 8k | 3.2 | 22.36±0.09 | N/A | N/A
dataparallel-256 | 2 | 256 | 0.1 | 22.02±0.11 | 180.344 | 1.000x
dataparallel-1k | 8 | 1k | 0.4 | 22.04±0.24 | 606.916 | 3.365x
dataparallel-4k | 8 | 4k | 1.6 | OOM | N/A | N/A
pipeline-256 | 2 | 256 | 0.1 | 21.99±0.13 | 117.432 | 0.651x
pipeline-1k | 8 | 1k | 0.4 | 22.24±0.19 | 294.739 | 1.634x
pipeline-4k | 8 | 4k | 1.6 | 22.13±0.09 | 378.746 | 2.100x
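
The learning-rate column follows the paper's linear scaling rule: the base LR of 0.1 at kn = 256 is scaled linearly with the total minibatch size, and the paper pairs it with a gradual warmup for large minibatches. A small sketch under those assumptions — not necessarily the exact schedule implemented in `main.py`:

```python
def scaled_lr(batch_size: int, base_lr: float = 0.1, base_batch: int = 256) -> float:
    """Linear scaling rule: lr = base_lr * kn / 256."""
    return base_lr * batch_size / base_batch

def warmup_lr(epoch: int, batch_size: int, warmup_epochs: int = 5) -> float:
    """Gradual warmup: ramp from the kn=256 LR up to the scaled LR over the
    first few epochs, then stay at the scaled LR (as prescribed by the paper
    for large minibatches)."""
    target = scaled_lr(batch_size)
    if batch_size <= 256 or epoch >= warmup_epochs:
        return target
    start = scaled_lr(256)  # 0.1
    return start + (target - start) * epoch / warmup_epochs

for kn in (256, 1024, 4096):
    print(kn, round(scaled_lr(kn), 4))  # -> 0.1, 0.4, 1.6, as in the table
```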



## Optimized Environment

- Python 3.6.9
- PyTorch 1.2.0
- CUDA 10.0.130
- 8 Tesla P40 GPUs
- 8+ Intel E5-2650 v4 CPUs

## To Reproduce

First, resolve the dependencies. We highly recommend using a separate virtual
environment dedicated to this benchmark:

```sh
$ pip install -r requirements.txt
```

Prepare the ImageNet dataset at `./data/imagenet`:

```sh
$ python -c "import torchvision; torchvision.datasets.ImageNet('./data/imagenet', split='train', download=True)"
$ python -c "import torchvision; torchvision.datasets.ImageNet('./data/imagenet', split='val', download=True)"
```
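
The directory prepared above uses the standard torchvision `ImageNet` layout. For reference, a typical way to load it for ResNet training is sketched below; the exact transforms and loader settings in `main.py` may differ, so treat these values as assumptions:

```python
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Standard ImageNet preprocessing for ResNet training (illustrative values).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_set = torchvision.datasets.ImageNet(
    './data/imagenet', split='train',
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]))

train_loader = DataLoader(train_set, batch_size=256, shuffle=True,
                          num_workers=8, pin_memory=True)
```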

Then, run each benchmark:

```sh
$ python main.py naive-128
$ python main.py --devices 0,1 dataparallel # 256
$ python main.py dataparallel # 1k
$ python main.py gpipe-2-256 # gpipe 256
$ python main.py gpipe-8 # gpipe 1k
$ python main.py gpipe-8-4k # gpipe 4k
```


[Accurate, Large Minibatch SGD]: https://arxiv.org/abs/1706.02677
[paper]: https://arxiv.org/abs/1706.02677
